From 5b62bef8cbf73f910513ef3b1f557aa94b384854 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 19 Aug 2015 13:17:26 -0700 Subject: [PATCH 001/802] [SPARK-8918] [MLLIB] [DOC] Add @since tags to mllib.clustering This continues the work from #8256. I removed `since` tags from private/protected/local methods/variables (see https://github.com/apache/spark/commit/72fdeb64630470f6f46cf3eed8ffbfe83a7c4659). MechCoder Closes #8256 Author: Xiangrui Meng Author: Xiaoqing Wang Author: MechCoder Closes #8288 from mengxr/SPARK-8918. --- .../mllib/clustering/GaussianMixture.scala | 56 +++++++++++---- .../clustering/GaussianMixtureModel.scala | 32 +++++++-- .../spark/mllib/clustering/KMeans.scala | 36 +++++++++- .../spark/mllib/clustering/KMeansModel.scala | 37 ++++++++-- .../apache/spark/mllib/clustering/LDA.scala | 71 ++++++++++++++++--- .../spark/mllib/clustering/LDAModel.scala | 64 +++++++++++++++-- .../spark/mllib/clustering/LDAOptimizer.scala | 12 +++- .../clustering/PowerIterationClustering.scala | 29 +++++++- .../mllib/clustering/StreamingKMeans.scala | 53 +++++++++++--- 9 files changed, 338 insertions(+), 52 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index e459367333d26..bc27b1fe7390b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -62,6 +62,7 @@ class GaussianMixture private ( /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. + * @since 1.3.0 */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) @@ -72,9 +73,11 @@ class GaussianMixture private ( // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - /** Set the initial GMM starting point, bypassing the random initialization. - * You must call setK() prior to calling this method, and the condition - * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + /** + * Set the initial GMM starting point, bypassing the random initialization. + * You must call setK() prior to calling this method, and the condition + * (model.k == this.k) must be met; failure will result in an IllegalArgumentException + * @since 1.3.0 */ def setInitialModel(model: GaussianMixtureModel): this.type = { if (model.k == k) { @@ -85,30 +88,46 @@ class GaussianMixture private ( this } - /** Return the user supplied initial GMM, if supplied */ + /** + * Return the user supplied initial GMM, if supplied + * @since 1.3.0 + */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - /** Set the number of Gaussians in the mixture model. Default: 2 */ + /** + * Set the number of Gaussians in the mixture model. Default: 2 + * @since 1.3.0 + */ def setK(k: Int): this.type = { this.k = k this } - /** Return the number of Gaussians in the mixture model */ + /** + * Return the number of Gaussians in the mixture model + * @since 1.3.0 + */ def getK: Int = k - /** Set the maximum number of iterations to run. Default: 100 */ + /** + * Set the maximum number of iterations to run. 
Default: 100 + * @since 1.3.0 + */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Return the maximum number of iterations to run */ + /** + * Return the maximum number of iterations to run + * @since 1.3.0 + */ def getMaxIterations: Int = maxIterations /** * Set the largest change in log-likelihood at which convergence is * considered to have occurred. + * @since 1.3.0 */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol @@ -118,19 +137,29 @@ class GaussianMixture private ( /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. + * @since 1.3.0 */ def getConvergenceTol: Double = convergenceTol - /** Set the random seed */ + /** + * Set the random seed + * @since 1.3.0 + */ def setSeed(seed: Long): this.type = { this.seed = seed this } - /** Return the random seed */ + /** + * Return the random seed + * @since 1.3.0 + */ def getSeed: Long = seed - /** Perform expectation maximization */ + /** + * Perform expectation maximization + * @since 1.3.0 + */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -204,7 +233,10 @@ class GaussianMixture private ( new GaussianMixtureModel(weights, gaussians) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + * @since 1.3.0 + */ def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) private def updateWeightsAndGaussians( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 76aeebd703d4e..2fa0473737aae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.{SQLContext, Row} * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents * the Multivariate Gaussian (Normal) Distribution for Gaussian i + * @since 1.3.0 */ @Experimental class GaussianMixtureModel( @@ -53,32 +54,48 @@ class GaussianMixtureModel( override protected def formatVersion = "1.0" + /** + * @since 1.4.0 + */ override def save(sc: SparkContext, path: String): Unit = { GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) } - /** Number of gaussians in mixture */ + /** + * Number of gaussians in mixture + * @since 1.3.0 + */ def k: Int = weights.length - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + * @since 1.3.0 + */ def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - /** Maps given point to its cluster index. */ + /** + * Maps given point to its cluster index. + * @since 1.5.0 + */ def predict(point: Vector): Int = { val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) r.indexOf(r.max) } - /** Java-friendly version of [[predict()]] */ + /** + * Java-friendly version of [[predict()]] + * @since 1.4.0 + */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] /** * Given the input vectors, return the membership value of each vector * to all mixture components. 
+ * @since 1.3.0 */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext @@ -91,6 +108,7 @@ class GaussianMixtureModel( /** * Given the input vector, return the membership values to all mixture components. + * @since 1.4.0 */ def predictSoft(point: Vector): Array[Double] = { computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) @@ -115,6 +133,9 @@ class GaussianMixtureModel( } } +/** + * @since 1.4.0 + */ @Experimental object GaussianMixtureModel extends Loader[GaussianMixtureModel] { @@ -165,6 +186,9 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } } + /** + * @since 1.4.0 + */ override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 0a65403f4ec95..9ef6834e5ea8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -49,15 +49,20 @@ class KMeans private ( /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. + * @since 0.8.0 */ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). + * @since 1.4.0 */ def getK: Int = k - /** Set the number of clusters to create (k). Default: 2. */ + /** + * Set the number of clusters to create (k). Default: 2. + * @since 0.8.0 + */ def setK(k: Int): this.type = { this.k = k this @@ -65,10 +70,14 @@ class KMeans private ( /** * Maximum number of iterations to run. + * @since 1.4.0 */ def getMaxIterations: Int = maxIterations - /** Set maximum number of iterations to run. Default: 20. */ + /** + * Set maximum number of iterations to run. Default: 20. + * @since 0.8.0 + */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -76,6 +85,7 @@ class KMeans private ( /** * The initialization algorithm. This can be either "random" or "k-means||". + * @since 1.4.0 */ def getInitializationMode: String = initializationMode @@ -83,6 +93,7 @@ class KMeans private ( * Set the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. + * @since 0.8.0 */ def setInitializationMode(initializationMode: String): this.type = { KMeans.validateInitMode(initializationMode) @@ -93,6 +104,7 @@ class KMeans private ( /** * :: Experimental :: * Number of runs of the algorithm to execute in parallel. + * @since 1.4.0 */ @Experimental def getRuns: Int = runs @@ -102,6 +114,7 @@ class KMeans private ( * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. 
+ * @since 0.8.0 */ @Experimental def setRuns(runs: Int): this.type = { @@ -114,12 +127,14 @@ class KMeans private ( /** * Number of steps for the k-means|| initialization mode + * @since 1.4.0 */ def getInitializationSteps: Int = initializationSteps /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. + * @since 0.8.0 */ def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { @@ -131,12 +146,14 @@ class KMeans private ( /** * The distance threshold within which we've consider centers to have converged. + * @since 1.4.0 */ def getEpsilon: Double = epsilon /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. + * @since 0.8.0 */ def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon @@ -145,10 +162,14 @@ class KMeans private ( /** * The random seed for cluster initialization. + * @since 1.4.0 */ def getSeed: Long = seed - /** Set the random seed for cluster initialization. */ + /** + * Set the random seed for cluster initialization. + * @since 1.4.0 + */ def setSeed(seed: Long): this.type = { this.seed = seed this @@ -162,6 +183,7 @@ class KMeans private ( * Set the initial starting point, bypassing the random initialization or k-means|| * The condition model.k == this.k must be met, failure results * in an IllegalArgumentException. + * @since 1.4.0 */ def setInitialModel(model: KMeansModel): this.type = { require(model.k == k, "mismatched cluster count") @@ -172,6 +194,7 @@ class KMeans private ( /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. + * @since 0.8.0 */ def run(data: RDD[Vector]): KMeansModel = { @@ -430,11 +453,14 @@ class KMeans private ( /** * Top-level methods for calling K-means clustering. + * @since 0.8.0 */ object KMeans { // Initialization mode names + /** @since 0.8.0 */ val RANDOM = "random" + /** @since 0.8.0 */ val K_MEANS_PARALLEL = "k-means||" /** @@ -446,6 +472,7 @@ object KMeans { * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). * @param seed random seed value for cluster initialization + * @since 1.3.0 */ def train( data: RDD[Vector], @@ -470,6 +497,7 @@ object KMeans { * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). + * @since 0.8.0 */ def train( data: RDD[Vector], @@ -486,6 +514,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. + * @since 0.8.0 */ def train( data: RDD[Vector], @@ -496,6 +525,7 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. 
+ * @since 0.8.0 */ def train( data: RDD[Vector], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 96359024fa228..8de2087ceb4df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -34,35 +34,52 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. + * @since 0.8.0 */ class KMeansModel ( val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { - /** A Java-friendly constructor that takes an Iterable of Vectors. */ + /** + * A Java-friendly constructor that takes an Iterable of Vectors. + * @since 1.4.0 + */ def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) - /** Total number of clusters. */ + /** + * Total number of clusters. + * @since 0.8.0 + */ def k: Int = clusterCenters.length - /** Returns the cluster index that a given point belongs to. */ + /** + * Returns the cluster index that a given point belongs to. + * @since 0.8.0 + */ def predict(point: Vector): Int = { KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + * @since 1.0.0 + */ def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = points.context.broadcast(centersWithNorm) points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1) } - /** Maps given points to their cluster indices. */ + /** + * Maps given points to their cluster indices. + * @since 1.0.0 + */ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. + * @since 0.8.0 */ def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm @@ -73,6 +90,9 @@ class KMeansModel ( private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) + /** + * @since 1.4.0 + */ override def save(sc: SparkContext, path: String): Unit = { KMeansModel.SaveLoadV1_0.save(sc, this, path) } @@ -80,7 +100,14 @@ class KMeansModel ( override protected def formatVersion: String = "1.0" } +/** + * @since 1.4.0 + */ object KMeansModel extends Loader[KMeansModel] { + + /** + * @since 1.4.0 + */ override def load(sc: SparkContext, path: String): KMeansModel = { KMeansModel.SaveLoadV1_0.load(sc, path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 0fc9b1ac4d716..2a8c6acbaec61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -43,6 +43,7 @@ import org.apache.spark.util.Utils * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] + * @since 1.3.0 */ @Experimental class LDA private ( @@ -54,18 +55,25 @@ class LDA private ( private var checkpointInterval: Int, private var ldaOptimizer: LDAOptimizer) extends Logging { + /** + * Constructs a LDA instance with default parameters. 
+ * @since 1.3.0 + */ def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. + * + * @since 1.3.0 */ def getK: Int = k /** * Number of topics to infer. I.e., the number of soft cluster centers. * (default = 10) + * @since 1.3.0 */ def setK(k: Int): this.type = { require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") @@ -78,6 +86,7 @@ class LDA private ( * distributions over topics ("theta"). * * This is the parameter to a Dirichlet distribution. + * @since 1.5.0 */ def getAsymmetricDocConcentration: Vector = this.docConcentration @@ -87,6 +96,7 @@ class LDA private ( * * This method assumes the Dirichlet distribution is symmetric and can be described by a single * [[Double]] parameter. It should fail if docConcentration is asymmetric. + * @since 1.3.0 */ def getDocConcentration: Double = { val parameter = docConcentration(0) @@ -121,6 +131,7 @@ class LDA private ( * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * @since 1.5.0 */ def setDocConcentration(docConcentration: Vector): this.type = { require(docConcentration.size > 0, "docConcentration must have > 0 elements") @@ -128,22 +139,37 @@ class LDA private ( this } - /** Replicates a [[Double]] docConcentration to create a symmetric prior. */ + /** + * Replicates a [[Double]] docConcentration to create a symmetric prior. + * @since 1.3.0 + */ def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this } - /** Alias for [[getAsymmetricDocConcentration]] */ + /** + * Alias for [[getAsymmetricDocConcentration]] + * @since 1.5.0 + */ def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration - /** Alias for [[getDocConcentration]] */ + /** + * Alias for [[getDocConcentration]] + * @since 1.3.0 + */ def getAlpha: Double = getDocConcentration - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + * @since 1.5.0 + */ def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) - /** Alias for [[setDocConcentration()]] */ + /** + * Alias for [[setDocConcentration()]] + * @since 1.3.0 + */ def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) /** @@ -154,6 +180,7 @@ class LDA private ( * * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * @since 1.3.0 */ def getTopicConcentration: Double = this.topicConcentration @@ -178,36 +205,51 @@ class LDA private ( * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * @since 1.3.0 */ def setTopicConcentration(topicConcentration: Double): this.type = { this.topicConcentration = topicConcentration this } - /** Alias for [[getTopicConcentration]] */ + /** + * Alias for [[getTopicConcentration]] + * @since 1.3.0 + */ def getBeta: Double = getTopicConcentration - /** Alias for [[setTopicConcentration()]] */ + /** + * Alias for [[setTopicConcentration()]] + * @since 1.3.0 + */ def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. 
+ * @since 1.3.0 */ def getMaxIterations: Int = maxIterations /** * Maximum number of iterations for learning. * (default = 20) + * @since 1.3.0 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - /** Random seed */ + /** + * Random seed + * @since 1.3.0 + */ def getSeed: Long = seed - /** Random seed */ + /** + * Random seed + * @since 1.3.0 + */ def setSeed(seed: Long): this.type = { this.seed = seed this @@ -215,6 +257,7 @@ class LDA private ( /** * Period (in iterations) between checkpoints. + * @since 1.3.0 */ def getCheckpointInterval: Int = checkpointInterval @@ -225,6 +268,7 @@ class LDA private ( * [[org.apache.spark.SparkContext]], this setting is ignored. * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] + * @since 1.3.0 */ def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval @@ -236,6 +280,7 @@ class LDA private ( * :: DeveloperApi :: * * LDAOptimizer used to perform the actual calculation + * @since 1.4.0 */ @DeveloperApi def getOptimizer: LDAOptimizer = ldaOptimizer @@ -244,6 +289,7 @@ class LDA private ( * :: DeveloperApi :: * * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + * @since 1.4.0 */ @DeveloperApi def setOptimizer(optimizer: LDAOptimizer): this.type = { @@ -254,6 +300,7 @@ class LDA private ( /** * Set the LDAOptimizer used to perform the actual calculation by algorithm name. * Currently "em", "online" are supported. + * @since 1.4.0 */ def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = @@ -274,6 +321,7 @@ class LDA private ( * (where the vocabulary size is the length of the vector). * Document IDs must be unique and >= 0. * @return Inferred LDA model + * @since 1.3.0 */ def run(documents: RDD[(Long, Vector)]): LDAModel = { val state = ldaOptimizer.initialize(documents, this) @@ -289,7 +337,10 @@ class LDA private ( state.getLDAModel(iterationTimes) } - /** Java-friendly version of [[run()]] */ + /** + * Java-friendly version of [[run()]] + * @since 1.3.0 + */ def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 82f05e4a18cee..b70e380c0393e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -192,12 +192,24 @@ class LocalLDAModel private[clustering] ( override protected[clustering] val gammaShape: Double = 100) extends LDAModel with Serializable { + /** + * @since 1.3.0 + */ override def k: Int = topics.numCols + /** + * @since 1.3.0 + */ override def vocabSize: Int = topics.numRows + /** + * @since 1.3.0 + */ override def topicsMatrix: Matrix = topics + /** + * @since 1.3.0 + */ override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix Range(0, k).map { topicIndex => @@ -210,6 +222,9 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" + /** + * @since 1.5.0 + */ override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -223,12 +238,16 @@ class LocalLDAModel private[clustering] ( * * @param documents test corpus to use for 
calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus + * @since 1.5.0 */ def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) - /** Java-friendly version of [[logLikelihood]] */ + /** + * Java-friendly version of [[logLikelihood]] + * @since 1.5.0 + */ def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } @@ -239,6 +258,7 @@ class LocalLDAModel private[clustering] ( * * @param documents test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. + * @since 1.5.0 */ def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusTokenCount = documents @@ -247,7 +267,9 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } - /** Java-friendly version of [[logPerplexity]] */ + /** Java-friendly version of [[logPerplexity]] + * @since 1.5.0 + */ def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } @@ -325,6 +347,7 @@ class LocalLDAModel private[clustering] ( * for each document. * @param documents documents to predict topic mixture distributions for * @return An RDD of (document ID, topic mixture distribution for document) + * @since 1.3.0 */ // TODO: declare in LDAModel and override once implemented in DistributedLDAModel def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { @@ -351,7 +374,10 @@ class LocalLDAModel private[clustering] ( } } - /** Java-friendly version of [[topicDistributions]] */ + /** + * Java-friendly version of [[topicDistributions]] + * @since 1.4.1 + */ def topicDistributions( documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) @@ -425,6 +451,9 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } } + /** + * @since 1.5.0 + */ override def load(sc: SparkContext, path: String): LocalLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -481,6 +510,7 @@ class DistributedLDAModel private[clustering] ( * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. + * @since 1.3.0 */ def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -491,6 +521,7 @@ class DistributedLDAModel private[clustering] ( * No guarantees are given about the ordering of the topics. * * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. + * @since 1.3.0 */ override lazy val topicsMatrix: Matrix = { // Collect row-major topics @@ -510,6 +541,9 @@ class DistributedLDAModel private[clustering] ( Matrices.fromBreeze(brzTopics) } + /** + * @since 1.3.0 + */ override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights @@ -548,6 +582,7 @@ class DistributedLDAModel private[clustering] ( * @return Array over topics. 
Each element represent as a pair of matching arrays: * (IDs for the documents, weights of the topic in these documents). * For each topic, documents are sorted in order of decreasing topic weights. + * @since 1.5.0 */ def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { val numTopics = k @@ -587,6 +622,7 @@ class DistributedLDAModel private[clustering] ( * - This excludes the prior; for that, use [[logPrior]]. * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the * hyperparameters. + * @since 1.3.0 */ lazy val logLikelihood: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha @@ -612,7 +648,8 @@ class DistributedLDAModel private[clustering] ( /** * Log probability of the current parameter estimate: - * log P(topics, topic distributions for docs | alpha, eta) + * log P(topics, topic distributions for docs | alpha, eta) + * @since 1.3.0 */ lazy val logPrior: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha @@ -644,6 +681,7 @@ class DistributedLDAModel private[clustering] ( * ("theta_doc"). * * @return RDD of (document ID, topic distribution) pairs + * @since 1.3.0 */ def topicDistributions: RDD[(Long, Vector)] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => @@ -651,7 +689,10 @@ class DistributedLDAModel private[clustering] ( } } - /** Java-friendly version of [[topicDistributions]] */ + /** + * Java-friendly version of [[topicDistributions]] + * @since 1.4.1 + */ def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } @@ -659,6 +700,7 @@ class DistributedLDAModel private[clustering] ( /** * For each document, return the top k weighted topics for that document and their weights. 
* @return RDD of (doc ID, topic indices, topic weights) + * @since 1.5.0 */ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => @@ -673,7 +715,10 @@ class DistributedLDAModel private[clustering] ( } } - /** Java-friendly version of [[topTopicsPerDocument]] */ + /** + * Java-friendly version of [[topTopicsPerDocument]] + * @since 1.5.0 + */ def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { val topics = topTopicsPerDocument(k) topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() @@ -684,6 +729,10 @@ class DistributedLDAModel private[clustering] ( override protected def formatVersion = "1.0" + /** + * Java-friendly version of [[topicDistributions]] + * @since 1.5.0 + */ override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, @@ -784,6 +833,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { } + /** + * @since 1.5.0 + */ override def load(sc: SparkContext, path: String): DistributedLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index a0008f9c99ad7..360241c8081ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD * * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can * hold optimizer-specific parameters for users to set. + * @since 1.4.0 */ @DeveloperApi sealed trait LDAOptimizer { @@ -73,7 +74,7 @@ sealed trait LDAOptimizer { * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. - * + * @since 1.4.0 */ @DeveloperApi final class EMLDAOptimizer extends LDAOptimizer { @@ -225,6 +226,7 @@ final class EMLDAOptimizer extends LDAOptimizer { * * Original Online LDA paper: * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. + * @since 1.4.0 */ @DeveloperApi final class OnlineLDAOptimizer extends LDAOptimizer { @@ -274,6 +276,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. + * @since 1.4.0 */ def getTau0: Double = this.tau0 @@ -281,6 +284,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. * Default: 1024, following the original Online LDA paper. + * @since 1.4.0 */ def setTau0(tau0: Double): this.type = { require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") @@ -290,6 +294,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Learning rate: exponential decay rate + * @since 1.4.0 */ def getKappa: Double = this.kappa @@ -297,6 +302,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Learning rate: exponential decay rate---should be between * (0.5, 1.0] to guarantee asymptotic convergence. 
* Default: 0.51, based on the original Online LDA paper. + * @since 1.4.0 */ def setKappa(kappa: Double): this.type = { require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") @@ -306,6 +312,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration + * @since 1.4.0 */ def getMiniBatchFraction: Double = this.miniBatchFraction @@ -318,6 +325,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * maxIterations * miniBatchFraction >= 1. * * Default: 0.05, i.e., 5% of total documents. + * @since 1.4.0 */ def setMiniBatchFraction(miniBatchFraction: Double): this.type = { require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, @@ -329,6 +337,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) * will be optimized during training. + * @since 1.5.0 */ def getOptimzeAlpha: Boolean = this.optimizeAlpha @@ -336,6 +345,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Sets whether to optimize alpha parameter during training. * * Default: false + * @since 1.5.0 */ def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { this.optimizeAlpha = optimizeAlpha diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 407e43a024a2e..b4733ca975152 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -39,12 +39,16 @@ import org.apache.spark.{Logging, SparkContext, SparkException} * * @param k number of clusters * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s + * @since 1.3.0 */ @Experimental class PowerIterationClusteringModel( val k: Int, val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { + /** + * @since 1.4.0 + */ override def save(sc: SparkContext, path: String): Unit = { PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) } @@ -52,6 +56,9 @@ class PowerIterationClusteringModel( override protected def formatVersion: String = "1.0" } +/** + * @since 1.4.0 + */ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) @@ -65,6 +72,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" + /** + * @since 1.4.0 + */ def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ @@ -77,6 +87,9 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode dataRDD.write.parquet(Loader.dataPath(path)) } + /** + * @since 1.4.0 + */ def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) @@ -120,13 +133,16 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ - /** Constructs a PIC instance with default parameters: {k: 2, 
maxIterations: 100, - * initMode: "random"}. + /** + * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. + * @since 1.3.0 */ def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. + * @since 1.3.0 */ def setK(k: Int): this.type = { this.k = k @@ -135,6 +151,7 @@ class PowerIterationClustering private[clustering] ( /** * Set maximum number of iterations of the power iteration loop + * @since 1.3.0 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations @@ -144,6 +161,7 @@ class PowerIterationClustering private[clustering] ( /** * Set the initialization mode. This can be either "random" to use a random vector * as vertex properties, or "degree" to use normalized sum similarities. Default: random. + * @since 1.3.0 */ def setInitializationMode(mode: String): this.type = { this.initMode = mode match { @@ -164,6 +182,7 @@ class PowerIterationClustering private[clustering] ( * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result + * @since 1.5.0 */ def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { val w = normalize(graph) @@ -185,6 +204,7 @@ class PowerIterationClustering private[clustering] ( * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result + * @since 1.3.0 */ def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) @@ -197,6 +217,7 @@ class PowerIterationClustering private[clustering] ( /** * A Java-friendly version of [[PowerIterationClustering.run]]. + * @since 1.3.0 */ def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) : PowerIterationClusteringModel = { @@ -221,6 +242,9 @@ class PowerIterationClustering private[clustering] ( } } +/** + * @since 1.3.0 + */ @Experimental object PowerIterationClustering extends Logging { @@ -229,6 +253,7 @@ object PowerIterationClustering extends Logging { * Cluster assignment. * @param id node id * @param cluster assigned cluster id + * @since 1.3.0 */ @Experimental case class Assignment(id: Long, cluster: Int) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d9b34cec64894..a915804b02c90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -63,6 +63,7 @@ import org.apache.spark.util.random.XORShiftRandom * such that at time t + h the discount applied to the data from t is 0.5. * The definition remains the same whether the time unit is given * as batches or points. + * @since 1.2.0 * */ @Experimental @@ -70,7 +71,10 @@ class StreamingKMeansModel( override val clusterCenters: Array[Vector], val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { - /** Perform a k-means update on a batch of data. */ + /** + * Perform a k-means update on a batch of data. 
+ * @since 1.2.0 + */ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { // find nearest cluster to each point @@ -82,6 +86,7 @@ class StreamingKMeansModel( (p1._1, p1._2 + p2._2) } val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) .collect() @@ -161,6 +166,7 @@ class StreamingKMeansModel( * .setRandomCenters(5, 100.0) * .trainOn(DStream) * }}} + * @since 1.2.0 */ @Experimental class StreamingKMeans( @@ -168,23 +174,33 @@ class StreamingKMeans( var decayFactor: Double, var timeUnit: String) extends Logging with Serializable { + /** @since 1.2.0 */ def this() = this(2, 1.0, StreamingKMeans.BATCHES) protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) - /** Set the number of clusters. */ + /** + * Set the number of clusters. + * @since 1.2.0 + */ def setK(k: Int): this.type = { this.k = k this } - /** Set the decay factor directly (for forgetful algorithms). */ + /** + * Set the decay factor directly (for forgetful algorithms). + * @since 1.2.0 + */ def setDecayFactor(a: Double): this.type = { this.decayFactor = a this } - /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + /** + * Set the half life and time unit ("batches" or "points") for forgetful algorithms. + * @since 1.2.0 + */ def setHalfLife(halfLife: Double, timeUnit: String): this.type = { if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) @@ -195,7 +211,10 @@ class StreamingKMeans( this } - /** Specify initial centers directly. */ + /** + * Specify initial centers directly. + * @since 1.2.0 + */ def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { model = new StreamingKMeansModel(centers, weights) this @@ -207,6 +226,7 @@ class StreamingKMeans( * @param dim Number of dimensions * @param weight Weight for each center * @param seed Random seed + * @since 1.2.0 */ def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { val random = new XORShiftRandom(seed) @@ -216,7 +236,10 @@ class StreamingKMeans( this } - /** Return the latest model. */ + /** + * Return the latest model. + * @since 1.2.0 + */ def latestModel(): StreamingKMeansModel = { model } @@ -228,6 +251,7 @@ class StreamingKMeans( * and updates the model using each batch of data from the stream. * * @param data DStream containing vector data + * @since 1.2.0 */ def trainOn(data: DStream[Vector]) { assertInitialized() @@ -236,7 +260,10 @@ class StreamingKMeans( } } - /** Java-friendly version of `trainOn`. */ + /** + * Java-friendly version of `trainOn`. + * @since 1.4.0 + */ def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) /** @@ -244,13 +271,17 @@ class StreamingKMeans( * * @param data DStream containing vector data * @return DStream containing predictions + * @since 1.2.0 */ def predictOn(data: DStream[Vector]): DStream[Int] = { assertInitialized() data.map(model.predict) } - /** Java-friendly version of `predictOn`. */ + /** + * Java-friendly version of `predictOn`. 
+ * @since 1.4.0 + */ def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) } @@ -261,13 +292,17 @@ class StreamingKMeans( * @param data DStream containing (key, feature vector) pairs * @tparam K key type * @return DStream containing the input keys and the predictions as values + * @since 1.2.0 */ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { assertInitialized() data.mapValues(model.predict) } - /** Java-friendly version of `predictOnValues`. */ + /** + * Java-friendly version of `predictOnValues`. + * @since 1.4.0 + */ def predictOnValues[K]( data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { implicit val tag = fakeClassTag[K] From f3391ff2b8b9c1f1308755dc223776692e3b7725 Mon Sep 17 00:00:00 2001 From: Joshi Date: Wed, 19 Aug 2015 21:23:02 +0100 Subject: [PATCH 002/802] [SPARK-8889] [CORE] Fix for OOM for graph creation Fix for OOM for graph creation Author: Joshi Author: Rekha Joshi Closes #7602 from rekhajoshm/SPARK-8889. --- .../spark/ui/scope/RDDOperationGraph.scala | 23 +++++------ .../org/apache/spark/ui/UISeleniumSuite.scala | 39 +++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index ffea9817c0b08..81f168a447ead 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo @@ -167,7 +167,7 @@ private[ui] object RDDOperationGraph extends Logging { def makeDotFile(graph: RDDOperationGraph): String = { val dotFile = new StringBuilder dotFile.append("digraph G {\n") - dotFile.append(makeDotSubgraph(graph.rootCluster, indent = " ")) + makeDotSubgraph(dotFile, graph.rootCluster, indent = " ") graph.edges.foreach { edge => dotFile.append(s""" ${edge.fromId}->${edge.toId};\n""") } dotFile.append("}") val result = dotFile.toString() @@ -180,18 +180,19 @@ private[ui] object RDDOperationGraph extends Logging { s"""${node.id} [label="${node.name} [${node.id}]"]""" } - /** Return the dot representation of a subgraph in an RDDOperationGraph. */ - private def makeDotSubgraph(cluster: RDDOperationCluster, indent: String): String = { - val subgraph = new StringBuilder - subgraph.append(indent + s"subgraph cluster${cluster.id} {\n") - subgraph.append(indent + s""" label="${cluster.name}";\n""") + /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. 
*/ + private def makeDotSubgraph( + subgraph: StringBuilder, + cluster: RDDOperationCluster, + indent: String): Unit = { + subgraph.append(indent).append(s"subgraph cluster${cluster.id} {\n") + subgraph.append(indent).append(s""" label="${cluster.name}";\n""") cluster.childNodes.foreach { node => - subgraph.append(indent + s" ${makeDotNode(node)};\n") + subgraph.append(indent).append(s" ${makeDotNode(node)};\n") } cluster.childClusters.foreach { cscope => - subgraph.append(makeDotSubgraph(cscope, indent + " ")) + makeDotSubgraph(subgraph, cscope, indent + " ") } - subgraph.append(indent + "}\n") - subgraph.toString() + subgraph.append(indent).append("}\n") } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 3aa672f8b713c..69888b2694bae 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import scala.io.Source import scala.collection.JavaConversions._ import scala.xml.Node @@ -603,6 +604,44 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } + test("job stages should have expected dotfile under DAG visualization") { + withSpark(newSparkContext()) { sc => + // Create a multi-stage job + val rdd = + sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) + rdd.count() + + val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString + assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + + "label="Stage 0";\n subgraph ")) + assert(stage0.contains("{\n label="parallelize";\n " + + "0 [label="ParallelCollectionRDD [0]"];\n }")) + assert(stage0.contains("{\n label="map";\n " + + "1 [label="MapPartitionsRDD [1]"];\n }")) + assert(stage0.contains("{\n label="groupBy";\n " + + "2 [label="MapPartitionsRDD [2]"];\n }")) + + val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString + assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + + "label="Stage 1";\n subgraph ")) + assert(stage1.contains("{\n label="groupBy";\n " + + "3 [label="ShuffledRDD [3]"];\n }")) + assert(stage1.contains("{\n label="map";\n " + + "4 [label="MapPartitionsRDD [4]"];\n }")) + assert(stage1.contains("{\n label="groupBy";\n " + + "5 [label="MapPartitionsRDD [5]"];\n }")) + + val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString + assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + + "label="Stage 2";\n subgraph ")) + assert(stage2.contains("{\n label="groupBy";\n " + + "6 [label="ShuffledRDD [6]"];\n }")) + } + } + def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } From e05da5cb5ea253e6372f648fc8203204f2a8df8d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 19 Aug 2015 13:43:04 -0700 Subject: [PATCH 003/802] [SPARK-10107] [SQL] fix NPE in format_number Author: Davies Liu Closes #8305 from davies/format_number. 
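Before the diff for this format_number fix, a quick hedged illustration of the scenario it targets, written spark-shell style (the `sqlContext` binding and its implicits are assumed from a 1.5-era shell session, not part of the patch): when consecutive rows pass the same decimal-places argument, the generated code reuses the cached DecimalFormat, and prior to this change the result assignment only happened on the branch that rebuilt the pattern.

```scala
// Hedged sketch mirroring the StringFunctionsSuite change in this commit.
// Assumes a Spark 1.5-era shell where `sqlContext` is predefined.
import sqlContext.implicits._

val df = Seq((5L, 4), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b")

// The repeated (4L, 3) rows keep d == lastDValue, the path that previously
// skipped assigning the formatted result in the generated code.
df.filter("b > 0").selectExpr("format_number(a, b)").show()
// Expected with the fix: 5.0000, 4.000, 4.000, 3.00
```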
--- .../spark/sql/catalyst/expressions/stringOperations.scala | 2 +- .../scala/org/apache/spark/sql/StringFunctionsSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 134f1aa2af9a8..ca044d3e95e38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -1306,8 +1306,8 @@ case class FormatNumber(x: Expression, d: Expression) $df $dFormat = new $df($pattern.toString()); $lastDValue = $d; $numberFormat.applyPattern($dFormat.toPattern()); - ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } + ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); } else { ${ev.primitive} = null; ${ev.isNull} = true; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index cc95eede005d7..b91438baea06f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -348,9 +348,9 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { // it will still use the interpretProjection if projection follows by a LocalRelation, // hence we add a filter operator. // See the optimizer rule `ConvertToLocalRelation` - val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + val df2 = Seq((5L, 4), (4L, 3), (4L, 3), (4L, 3), (3L, 2)).toDF("a", "b") checkAnswer( df2.filter("b>0").selectExpr("format_number(a, b)"), - Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) + Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) } } From 08887369c890e0dd87eb8b34e8c32bb03307bf24 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 19 Aug 2015 13:56:40 -0700 Subject: [PATCH 004/802] [SPARK-10073] [SQL] Python withColumn should replace the old column DataFrame.withColumn in Python should be consistent with the Scala one (replacing the existing column that has the same name). cc marmbrus Author: Davies Liu Closes #8300 from davies/with_column. --- python/pyspark/sql/dataframe.py | 12 ++++++------ python/pyspark/sql/tests.py | 4 ++++ .../main/scala/org/apache/spark/sql/DataFrame.scala | 3 ++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index da742d7ce7d13..025811f519293 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1202,7 +1202,9 @@ def freqItems(self, cols, support=None): @ignore_unicode_prefix @since(1.3) def withColumn(self, colName, col): - """Returns a new :class:`DataFrame` by adding a column. + """ + Returns a new :class:`DataFrame` by adding a column or replacing the + existing column that has the same name. :param colName: string, name of the new column. :param col: a :class:`Column` expression for the new column. 
@@ -1210,7 +1212,8 @@ def withColumn(self, colName, col): >>> df.withColumn('age2', df.age + 2).collect() [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)] """ - return self.select('*', col.alias(colName)) + assert isinstance(col, Column), "col should be Column" + return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) @ignore_unicode_prefix @since(1.3) @@ -1223,10 +1226,7 @@ def withColumnRenamed(self, existing, new): >>> df.withColumnRenamed('age', 'age2').collect() [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] """ - cols = [Column(_to_java_column(c)).alias(new) - if c == existing else c - for c in self.columns] - return self.select(*cols) + return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) @since(1.4) @ignore_unicode_prefix diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13cf647b66da8..aacfb34c77618 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1035,6 +1035,10 @@ def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", lambda: df.select(sha2(df.a, 1024)).collect()) + def test_with_column_with_existing_name(self): + keys = self.df.withColumn("key", self.df.key).select("key").collect() + self.assertEqual([r.key for r in keys], list(range(100))) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fd0ead4401193..d6688b24ae7d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1133,7 +1133,8 @@ class DataFrame private[sql]( ///////////////////////////////////////////////////////////////////////////// /** - * Returns a new [[DataFrame]] by adding a column. + * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has + * the same name. * @group dfops * @since 1.3.0 */ From 21bdbe9fe69be47be562de24216a469e5ee64c7b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 19 Aug 2015 13:57:52 -0700 Subject: [PATCH 005/802] [SPARK-9627] [SQL] Stops using Scala runtime reflection in DictionaryEncoding `DictionaryEncoding` uses Scala runtime reflection to avoid boxing costs while building the directory array. However, this code path may hit [SI-6240] [1] and throw exception. [1]: https://issues.scala-lang.org/browse/SI-6240 Author: Cheng Lian Closes #8306 from liancheng/spark-9627/in-memory-cache-scala-reflection. 
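Before the diff, a self-contained sketch (plain Scala, not the Spark classes themselves) of the pattern this change adopts: keep the decoded dictionary as `Array[Any]` so that no `ClassTag`, and therefore no runtime reflection, is required, and cast each element back to its concrete type at the single place it is read.

```scala
import scala.reflect.ClassTag

// Old shape: filling a typed array needs a ClassTag[T], which the removed code
// obtained through scala.reflect.runtime (the path that can hit SI-6240).
def buildTyped[T: ClassTag](n: Int)(read: () => T): Array[T] =
  Array.fill(n)(read())

// New shape: reflection-free; the cost is one cast per element access.
def buildUntyped[T](n: Int)(read: () => T): Array[Any] =
  Array.fill[Any](n)(read().asInstanceOf[Any])

val dictionary: Array[Any] = buildUntyped(3)(() => 42)
val value: Int = dictionary(0).asInstanceOf[Int]  // per-read cast, as in the new Decoder
```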
--- .../sql/columnar/InMemoryColumnarTableScan.scala | 1 - .../columnar/compression/compressionSchemes.scala | 15 ++++----------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 45f15fd04d4e2..66d429bc06198 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -120,7 +120,6 @@ private[sql] case class InMemoryRelation( new Iterator[CachedBatch] { def next(): CachedBatch = { val columnBuilders = output.map { attribute => - val columnType = ColumnType(attribute.dataType) ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) }.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index c91d960a0932b..ca910a99db082 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -270,20 +270,13 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - private val dictionary = { - // TODO Can we clean up this mess? Maybe move this to `DataType`? - implicit val classTag = { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) - } - - Array.fill(buffer.getInt()) { - columnType.extract(buffer) - } + private val dictionary: Array[Any] = { + val elementNum = buffer.getInt() + Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) } override def next(row: MutableRow, ordinal: Int): Unit = { - columnType.setField(row, ordinal, dictionary(buffer.getShort())) + columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType]) } override def hasNext: Boolean = buffer.hasRemaining From 1f4c4fe6dfd8cc52b5fddfd67a31a77edbb1a036 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 19 Aug 2015 14:03:47 -0700 Subject: [PATCH 006/802] [SPARK-10090] [SQL] fix decimal scale of division We should rounding the result of multiply/division of decimal to expected precision/scale, also check overflow. Author: Davies Liu Closes #8287 from davies/decimal_division. 
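Before the diff, a worked example of the result type the updated DecimalPrecision rule derives for division; the formulas come from the `Divide` case added below, while the concrete input types are only illustrative.

```scala
// For Divide(Decimal(p1, s1), Decimal(p2, s2)) the rule below uses:
//   scale     = max(6, s1 + p2 + 1)
//   precision = p1 - s1 + s2 + scale   (then capped by DecimalType.bounded)
val (p1, s1) = (10, 2)                    // illustrative dividend type Decimal(10, 2)
val (p2, s2) = (5, 0)                     // illustrative divisor type Decimal(5, 0)
val scale     = math.max(6, s1 + p2 + 1)  // 8
val precision = p1 - s1 + s2 + scale      // 16, so the result type is Decimal(16, 8)
// Both operands are first wrapped in PromotePrecision casts to their wider
// common type, and the division itself in CheckOverflow, which (per the commit
// message) rounds to the target scale and checks for overflow.
```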
--- .../catalyst/analysis/HiveTypeCoercion.scala | 28 +++++---- .../spark/sql/catalyst/expressions/Cast.scala | 32 +++++----- .../expressions/decimalFunctions.scala | 38 ++++++++++- .../org/apache/spark/sql/types/Decimal.scala | 4 +- .../expressions/DecimalExpressionSuite.scala | 63 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 23 ++++++- 6 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8581d6b496c15..62c27ee0b9ee0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -371,8 +371,8 @@ object HiveTypeCoercion { DecimalType.bounded(range + scale, scale) } - private def changePrecision(e: Expression, dataType: DataType): Expression = { - ChangeDecimalPrecision(Cast(e, dataType)) + private def promotePrecision(e: Expression, dataType: DataType): Expression = { + PromotePrecision(Cast(e, dataType)) } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { @@ -383,36 +383,42 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e + case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Add(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) - Subtract(changePrecision(e1, dt), changePrecision(e2, dt)) + CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2) - Multiply(changePrecision(e1, dt), changePrecision(e2, dt)) + val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) - Divide(changePrecision(e1, dt), changePrecision(e2, dt)) + val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), + max(6, s1 + p2 + 1)) + val widerType = widerDecimalType(p1, s1, p2, s2) + CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. 
val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)), + CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) - Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType) + CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), + resultType) case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 616b9e0e65b78..2db954257be35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -447,7 +447,7 @@ case class Cast(child: Expression, dataType: DataType) case StringType => castToStringCode(from, ctx) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) - case decimal: DecimalType => castToDecimalCode(from, decimal) + case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) @@ -528,14 +528,18 @@ case class Cast(child: Expression, dataType: DataType) } """ - private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + private[this] def castToDecimalCode( + from: DataType, + target: DecimalType, + ctx: CodeGenContext): CastFunction = { + val tmp = ctx.freshName("tmpDecimal") from match { case StringType => (c, evPrim, evNull) => s""" try { - Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -543,8 +547,8 @@ case class Cast(child: Expression, dataType: DataType) case BooleanType => (c, evPrim, evNull) => s""" - Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); + ${changePrecision(tmp, target, evPrim, evNull)} """ case DateType => // date can't cast to decimal in Hive @@ -553,29 +557,29 @@ case class Cast(child: Expression, dataType: DataType) // Note that we lose precision here. 
(c, evPrim, evNull) => s""" - Decimal tmpDecimal = Decimal.apply( + Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull)} """ case DecimalType() => (c, evPrim, evNull) => s""" - Decimal tmpDecimal = $c.clone(); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = $c.clone(); + ${changePrecision(tmp, target, evPrim, evNull)} """ case x: IntegralType => (c, evPrim, evNull) => s""" - Decimal tmpDecimal = Decimal.apply((long) $c); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply((long) $c); + ${changePrecision(tmp, target, evPrim, evNull)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles (c, evPrim, evNull) => s""" try { - Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision("tmpDecimal", target, evPrim, evNull)} + Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision(tmp, target, evPrim, evNull)} } catch (java.lang.NumberFormatException e) { $evNull = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index adb33e4c8d4a1..b7be12f7aa741 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -66,10 +66,44 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un * An expression used to wrap the children when promote the precision of DecimalType to avoid * promote multiple times. */ -case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression { +case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" - override def prettyName: String = "change_decimal_precision" + override def prettyName: String = "promote_precision" +} + +/** + * Rounds the decimal to given scale and check whether the decimal can fit in provided precision + * or not, returns null if not. 
+ */ +case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { + + override def nullable: Boolean = true + + override def nullSafeEval(input: Any): Any = { + val d = input.asInstanceOf[Decimal].clone() + if (d.changePrecision(dataType.precision, dataType.scale)) { + d + } else { + null + } + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + val tmp = ctx.freshName("tmp") + s""" + | Decimal $tmp = $eval.clone(); + | if ($tmp.changePrecision(${dataType.precision}, ${dataType.scale})) { + | ${ev.primitive} = $tmp; + | } else { + | ${ev.isNull} = true; + | } + """.stripMargin + }) + } + + override def toString: String = s"CheckOverflow($child, $dataType)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index d95805c24521c..c988f1d1b972e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -267,7 +267,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal + that.longVal, Math.max(precision, that.precision), scale) } else { - Decimal(toBigDecimal + that.toBigDecimal, precision, scale) + Decimal(toBigDecimal + that.toBigDecimal) } } @@ -275,7 +275,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal - that.longVal, Math.max(precision, that.precision), scale) } else { - Decimal(toBigDecimal - that.toBigDecimal, precision, scale) + Decimal(toBigDecimal - that.toBigDecimal) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala new file mode 100644 index 0000000000000..511f0307901df --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} + + +class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("UnscaledValue") { + val d1 = Decimal("10.1") + checkEvaluation(UnscaledValue(Literal(d1)), 101L) + val d2 = Decimal(101, 3, 1) + checkEvaluation(UnscaledValue(Literal(d2)), 101L) + checkEvaluation(UnscaledValue(Literal.create(null, DecimalType(2, 1))), null) + } + + test("MakeDecimal") { + checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) + checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + } + + test("PromotePrecision") { + val d1 = Decimal("10.1") + checkEvaluation(PromotePrecision(Literal(d1)), d1) + val d2 = Decimal(101, 3, 1) + checkEvaluation(PromotePrecision(Literal(d2)), d2) + checkEvaluation(PromotePrecision(Literal.create(null, DecimalType(2, 1))), null) + } + + test("CheckOverflow") { + val d1 = Decimal("10.1") + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 1)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 2)), d1) + checkEvaluation(CheckOverflow(Literal(d1), DecimalType(4, 3)), null) + + val d2 = Decimal(101, 3, 1) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 0)), Decimal("10")) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 1)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 2)), d2) + checkEvaluation(CheckOverflow(Literal(d2), DecimalType(4, 3)), null) + + checkEvaluation(CheckOverflow(Literal.create(null, DecimalType(2, 1)), DecimalType(3, 2)), null) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c329fdb2a6bb1..141468ca00d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql +import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.DefaultParserDialect +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -1608,6 +1609,24 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("decimal precision with multiply/division") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), 
Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(null)) + } + test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { val sc = sqlContext.sparkContext From f3ff4c41d2e32bd0f2419d1c9c68fcd0c2593e41 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 19 Aug 2015 14:15:28 -0700 Subject: [PATCH 007/802] [SPARK-9899] [SQL] Disables customized output committer when speculation is on Speculation hates direct output committer, as there are multiple corner cases that may cause data corruption and/or data loss. Please see this [PR comment] [1] for more details. [1]: https://github.com/apache/spark/pull/8191#issuecomment-131598385 Author: Cheng Lian Closes #8317 from liancheng/spark-9899/speculation-hates-direct-output-committer. --- .../datasources/WriterContainer.scala | 16 ++++++++- .../sql/sources/hadoopFsRelationSuites.scala | 34 +++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index e0147079e6997..78f48a5cd72c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -58,6 +58,9 @@ private[sql] abstract class BaseWriterContainer( // This is only used on driver side. @transient private val jobContext: JobContext = job + private val speculationEnabled: Boolean = + relation.sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + // The following fields are initialized and used on both driver and executor side. @transient protected var outputCommitter: OutputCommitter = _ @transient private var jobId: JobID = _ @@ -126,10 +129,21 @@ private[sql] abstract class BaseWriterContainer( // associated with the file output format since it is not safe to use a custom // committer for appending. For example, in S3, direct parquet output committer may // leave partial data in the destination dir when the the appending job fails. + // + // See SPARK-8578 for more details logInfo( - s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + "for appending.") defaultOutputCommitter + } else if (speculationEnabled) { + // When speculation is enabled, it's not safe to use customized output committer classes, + // especially direct output committers (e.g. `DirectParquetOutputCommitter`). + // + // See SPARK-9899 for more details. 
+ logInfo( + s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + + "because spark.speculation is configured to be true.") + defaultOutputCommitter } else { val committerClass = context.getConfiguration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8d0d9218ddd6a..5bbca14bad320 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -570,6 +570,40 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") } } + + test("SPARK-9899 Disable customized output committer when speculation is on") { + val clonedConf = new Configuration(configuration) + val speculationEnabled = + sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) + + try { + withTempPath { dir => + // Enables task speculation + sqlContext.sparkContext.conf.set("spark.speculation", "true") + + // Uses a customized output committer which always fails + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + + // Code below shouldn't throw since customized output committer should be disabled. + val df = sqlContext.range(10).coalesce(1) + df.write.format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df) + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) + } + } } // This class is used to test SPARK-8578. We should not use any custom output committer when From 373a376c04320aab228b5c385e2b788809877d3e Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 19 Aug 2015 14:31:51 -0700 Subject: [PATCH 008/802] [SPARK-10083] [SQL] CaseWhen should support type coercion of DecimalType and FractionalType create t1 (a decimal(7, 2), b long); select case when 1=1 then a else 1.0 end from t1; select case when 1=1 then a else b end from t1; Author: Daoyuan Wang Closes #8270 from adrian-wang/casewhenfractional. 
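A usage sketch of the coercion this enables, reusing the table from the commit message; it assumes a `sqlContext` with a registered table `t1(a decimal(7, 2), b bigint)` and is illustrative rather than part of the patch:

    // decimal(7, 2) vs a double branch: the decimal side widens to double.
    val asDouble = sqlContext.sql(
      "select case when 1 = 1 then a else 1.0 end from t1")

    // decimal(7, 2) vs bigint (decimal(20, 0)): both branches are cast to the
    // wider decimal(22, 2) instead of leaving the CASE expression unresolved.
    val asDecimal = sqlContext.sql(
      "select case when 1 = 1 then a else b end from t1")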
--- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- .../sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 62c27ee0b9ee0..f2f2ba2f96552 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -605,7 +605,7 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -622,7 +622,7 @@ object HiveTypeCoercion { case c: CaseKeyWhen if c.childrenResolved && !c.resolved => val maybeCommonType = - findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) + findWiderCommonType((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index cbdf453f600ab..6f33ab733b615 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -285,6 +285,17 @@ class HiveTypeCoercionSuite extends PlanTest { CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))) + ) + ruleTest(HiveTypeCoercion.CaseWhenCoercion, + CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))), + CaseWhen(Seq( + Literal(true), Cast(Literal(100L), DecimalType(22, 2)), + Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))) + ) } test("type coercion simplification for equal to") { From e0dd1309ac248375f429639801923570f14de18d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 19 Aug 2015 14:33:32 -0700 Subject: [PATCH 009/802] [SPARK-10119] [CORE] Fix isDynamicAllocationEnabled when config is expliticly disabled. Author: Marcelo Vanzin Closes #8316 from vanzin/SPARK-10119. 
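The corrected predicate fits in a few lines; this sketch mirrors the `UtilsSuite` cases added below (the local helper simply restates the patched logic, the real method lives in `Utils`):

    import org.apache.spark.SparkConf

    // Dynamic allocation counts as enabled only when the flag is explicitly
    // true and no fixed executor count is requested; an explicit "false"
    // no longer slips through just because the key is present.
    def dynamicAllocationEnabled(conf: SparkConf): Boolean =
      conf.getBoolean("spark.dynamicAllocation.enabled", false) &&
        conf.getInt("spark.executor.instances", 0) == 0

    val conf = new SparkConf()
    assert(!dynamicAllocationEnabled(conf))
    assert(!dynamicAllocationEnabled(conf.set("spark.dynamicAllocation.enabled", "false")))
    assert(dynamicAllocationEnabled(conf.set("spark.dynamicAllocation.enabled", "true")))
    assert(!dynamicAllocationEnabled(conf.set("spark.executor.instances", "1")))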
--- .../main/scala/org/apache/spark/util/Utils.scala | 2 +- .../scala/org/apache/spark/util/UtilsSuite.scala | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fddc24dbfc237..8313312226713 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2141,7 +2141,7 @@ private[spark] object Utils extends Logging { * the latter should override the former (SPARK-9092). */ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - conf.contains("spark.dynamicAllocation.enabled") && + conf.getBoolean("spark.dynamicAllocation.enabled", false) && conf.getInt("spark.executor.instances", 0) == 0 } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8f7e402d5f2a6..1fb81ad565b41 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -720,4 +720,18 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) } + + test("isDynamicAllocationEnabled") { + val conf = new SparkConf() + assert(Utils.isDynamicAllocationEnabled(conf) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "false")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.dynamicAllocation.enabled", "true")) === true) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "1")) === false) + assert(Utils.isDynamicAllocationEnabled( + conf.set("spark.executor.instances", "0")) === true) + } + } From b0dbaec4f942a47afde3490b9339ad3bd187024d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 19 Aug 2015 15:04:56 -0700 Subject: [PATCH 010/802] [SPARK-6489] [SQL] add column pruning for Generate This PR takes over https://github.com/apache/spark/pull/5358 Author: Wenchen Fan Closes #8268 from cloud-fan/6489. 
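As a rough illustration of what the new rule buys, assume a DataFrame `df` with scalar columns `a`, `b` and an array column `c` (the names are hypothetical); selecting only the exploded output should no longer force the unused columns through the Generate:

    import org.apache.spark.sql.functions.explode

    // Only `c` is referenced: after optimization a Project(c) is pushed beneath
    // the Generate and, since no child output is selected, Generate.join is
    // turned off, so `a` and `b` are pruned before the explode runs.
    val exploded = df.select(explode(df("c")).as("s"))
    exploded.explain(true)  // inspect the optimized plan for the pushed Project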
--- .../sql/catalyst/expressions/generators.scala | 2 - .../sql/catalyst/optimizer/Optimizer.scala | 16 ++++ .../optimizer/ColumnPruningSuite.scala | 84 +++++++++++++++++++ 3 files changed, 100 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index d474853355e5b..c0845e1a0102f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.Map - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 47b06cae15436..42457d5318b48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] { * * - Inserting Projections beneath the following operators: * - Aggregate + * - Generate * - Project <- Join * - LeftSemiJoin */ @@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = Project(a.references.toSeq, child)) + // Eliminate attributes that are not needed to calculate the Generate. + case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => + g.copy(child = Project(g.references.toSeq, g.child)) + + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => + p.copy(child = g.copy(join = false)) + + case p @ Project(projectList, g: Generate) if g.join => + val neededChildOutput = p.references -- g.generatorOutput ++ g.references + if (neededChildOutput == g.child.outputSet) { + p + } else { + Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child))) + } + case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child)) if (a.outputSet -- p.references).nonEmpty => Project( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala new file mode 100644 index 0000000000000..dbebcb86809de --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.Explode +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.types.StringType + +class ColumnPruningSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("Column pruning", FixedPoint(100), + ColumnPruning) :: Nil + } + + test("Column pruning for Generate when Generate.join = false") { + val input = LocalRelation('a.int, 'b.array(StringType)) + + val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + Generate(Explode('b), false, false, None, 's.string :: Nil, + Project('b.attr :: Nil, input)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Column pruning for Generate when Generate.join = true") { + val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + + val query = + Project(Seq('a, 's), + Generate(Explode('c), true, false, None, 's.string :: Nil, + input)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + Project(Seq('a, 's), + Generate(Explode('c), true, false, None, 's.string :: Nil, + Project(Seq('a, 'c), + input))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Turn Generate.join to false if possible") { + val input = LocalRelation('b.array(StringType)) + + val query = + Project(('s + 1).as("s+1") :: Nil, + Generate(Explode('b), true, false, None, 's.string :: Nil, + input)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = + Project(('s + 1).as("s+1") :: Nil, + Generate(Explode('b), false, false, None, 's.string :: Nil, + input)).analyze + + comparePlans(optimized, correctAnswer) + } + + // todo: add more tests for column pruning +} From 8e0a072f78b4902d5f7ccc6b15232ed202a117f9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 19 Aug 2015 15:43:08 -0700 Subject: [PATCH 011/802] [SPARK-9895] User Guide for RFormula Feature Transformer mengxr Author: Eric Liang Closes #8293 from ericl/docs-2. --- docs/ml-features.md | 108 ++++++++++++++++++ .../apache/spark/ml/feature/RFormula.scala | 4 +- 2 files changed, 110 insertions(+), 2 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index d0e8eeb7a757e..6309db97be4d0 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1477,3 +1477,111 @@ print(output.select("features", "clicked").first()) +## RFormula + +`RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. 
If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. + +**Examples** + +Assume that we have a DataFrame with the columns `id`, `country`, `hour`, and `clicked`: + +~~~ +id | country | hour | clicked +---|---------|------|--------- + 7 | "US" | 18 | 1.0 + 8 | "CA" | 12 | 0.0 + 9 | "NZ" | 15 | 0.0 +~~~ + +If we use `RFormula` with a formula string of `clicked ~ country + hour`, which indicates that we want to +predict `clicked` based on `country` and `hour`, after transformation we should get the following DataFrame: + +~~~ +id | country | hour | clicked | features | label +---|---------|------|---------|------------------|------- + 7 | "US" | 18 | 1.0 | [0.0, 0.0, 18.0] | 1.0 + 8 | "CA" | 12 | 0.0 | [0.0, 1.0, 12.0] | 0.0 + 9 | "NZ" | 15 | 0.0 | [1.0, 0.0, 15.0] | 0.0 +~~~ + +
+<div data-lang="scala" markdown="1">
+ +[`RFormula`](api/scala/index.html#org.apache.spark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight scala %} +import org.apache.spark.ml.feature.RFormula + +val dataset = sqlContext.createDataFrame(Seq( + (7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0) +)).toDF("id", "country", "hour", "clicked") +val formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label") +val output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
+ +<div data-lang="java" markdown="1">
+ +[`RFormula`](api/java/org/apache/spark/ml/feature/RFormula.html) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.RFormula; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("country", StringType, false), + createStructField("hour", IntegerType, false), + createStructField("clicked", DoubleType, false) +}); +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(7, "US", 18, 1.0), + RowFactory.create(8, "CA", 12, 0.0), + RowFactory.create(9, "NZ", 15, 0.0) +)); +DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + +RFormula formula = new RFormula() + .setFormula("clicked ~ country + hour") + .setFeaturesCol("features") + .setLabelCol("label"); + +DataFrame output = formula.fit(dataset).transform(dataset); +output.select("features", "label").show(); +{% endhighlight %} +
+ +<div data-lang="python" markdown="1">
+ +[`RFormula`](api/python/pyspark.ml.html#pyspark.ml.feature.RFormula) takes an R formula string, and optional parameters for the names of its output columns. + +{% highlight python %} +from pyspark.ml.feature import RFormula + +dataset = sqlContext.createDataFrame( + [(7, "US", 18, 1.0), + (8, "CA", 12, 0.0), + (9, "NZ", 15, 0.0)], + ["id", "country", "hour", "clicked"]) +formula = RFormula( + formula="clicked ~ country + hour", + featuresCol="features", + labelCol="label") +output = formula.fit(dataset).transform(dataset) +output.select("features", "label").show() +{% endhighlight %} +
+</div>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index a752dacd72d95..a7fa50444209b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -42,8 +42,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula - * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '.', '~', '+', and '-'. Also see the + * R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { From ba5f7e1842f2c5852b5309910c0d39926643da69 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 20 Aug 2015 08:13:25 +0800 Subject: [PATCH 012/802] [SPARK-10035] [SQL] Parquet filters does not process EqualNullSafe filter. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As I talked with Lian, 1. I added EquelNullSafe to ParquetFilters - It uses the same equality comparison filter with EqualTo since the Parquet filter performs actually null-safe equality comparison. 2. Updated the test code (ParquetFilterSuite) - Convert catalyst.Expression to sources.Filter - Removed Cast since only Literal is picked up as a proper Filter in DataSourceStrategy - Added EquelNullSafe comparison 3. Removed deprecated createFilter for catalyst.Expression Author: hyukjinkwon Author: 권혁진 Closes #8275 from HyukjinKwon/master. 
--- .../datasources/parquet/ParquetFilters.scala | 113 +++--------------- .../parquet/ParquetFilterSuite.scala | 63 ++++------ 2 files changed, 37 insertions(+), 139 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 63915e0a28655..c74c8388632f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -22,8 +22,6 @@ import java.nio.ByteBuffer import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration -import org.apache.parquet.filter2.compat.FilterCompat -import org.apache.parquet.filter2.compat.FilterCompat._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate._ import org.apache.parquet.io.api.Binary @@ -39,12 +37,6 @@ import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap { filter => - createFilter(filter) - }.reduceOption(FilterApi.and).map(FilterCompat.get) - } - case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { @@ -205,6 +197,16 @@ private[sql] object ParquetFilters { // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, // which can be casted to `false` implicitly. Please refer to the `eval` method of these // operators and the `SimplifyFilters` rule for details. + + // Hyukjin: + // I added [[EqualNullSafe]] with [[org.apache.parquet.filter2.predicate.Operators.Eq]]. + // So, it performs equality comparison identically when given [[sources.Filter]] is [[EqualTo]]. + // The reason why I did this is, that the actual Parquet filter checks null-safe equality + // comparison. + // So I added this and maybe [[EqualTo]] should be changed. It still seems fine though, because + // physical planning does not set `NULL` to [[EqualTo]] but changes it to [[IsNull]] and etc. + // Probably I missed something and obviously this should be changed. + predicate match { case sources.IsNull(name) => makeEq.lift(dataTypeOf(name)).map(_(name, null)) @@ -216,6 +218,11 @@ private[sql] object ParquetFilters { case sources.Not(sources.EqualTo(name, value)) => makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.EqualNullSafe(name, value) => + makeEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.Not(sources.EqualNullSafe(name, value)) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.LessThan(name, value) => makeLt.lift(dataTypeOf(name)).map(_(name, value)) case sources.LessThanOrEqual(name, value) => @@ -273,96 +280,6 @@ private[sql] object ParquetFilters { addMethod.invoke(null, classOf[Binary], enumTypeDescriptor) } - /** - * Converts Catalyst predicate expressions to Parquet filter predicates. - * - * @todo This can be removed once we get rid of the old Parquet support. - */ - def createFilter(predicate: Expression): Option[FilterPredicate] = { - // NOTE: - // - // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, - // which can be casted to `false` implicitly. 
Please refer to the `eval` method of these - // operators and the `SimplifyFilters` rule for details. - predicate match { - case IsNull(NamedExpression(name, dataType)) => - makeEq.lift(dataType).map(_(name, null)) - case IsNotNull(NamedExpression(name, dataType)) => - makeNotEq.lift(dataType).map(_(name, null)) - - case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeEq.lift(dataType).map(_(name, value)) - case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeEq.lift(dataType).map(_(name, value)) - - case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => - makeNotEq.lift(dataType).map(_(name, value)) - case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => - makeNotEq.lift(dataType).map(_(name, value)) - - case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGt.lift(dataType).map(_(name, value)) - case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGt.lift(dataType).map(_(name, value)) - - case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - - case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLt.lift(dataType).map(_(name, value)) - case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLt.lift(dataType).map(_(name, value)) - - case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => - makeLtEq.lift(dataType).map(_(name, value)) - case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => - makeLtEq.lift(dataType).map(_(name, value)) - - case And(lhs, rhs) => - (createFilter(lhs) ++ 
createFilter(rhs)).reduceOption(FilterApi.and) - - case Or(lhs, rhs) => - for { - lhsFilter <- createFilter(lhs) - rhsFilter <- createFilter(rhs) - } yield FilterApi.or(lhsFilter, rhsFilter) - - case Not(pred) => - createFilter(pred).map(FilterApi.not) - - case InSet(NamedExpression(name, dataType), valueSet) => - makeInSet.lift(dataType).map(_(name, valueSet)) - - case _ => None - } - } - /** * Note: Inside the Hadoop API we only have access to `Configuration`, not to * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 5b4e568bb9838..f067112cfca95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -24,9 +24,8 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types._ /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -55,20 +54,22 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + val analyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters - }.flatten.reduceOption(_ && _) + }.flatten + assert(analyzedPredicate.nonEmpty) - assert(maybeAnalyzedPredicate.isDefined) - maybeAnalyzedPredicate.foreach { pred => - val maybeFilter = ParquetFilters.createFilter(pred) + val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate) + assert(selectedFilters.nonEmpty) + + selectedFilters.foreach { pred => + val maybeFilter = ParquetFilters.createFilter(df.schema, pred) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") maybeFilter.foreach { f => // Doesn't bother checking type parameters here (e.g. 
`Eq[Integer]`) assert(f.getClass === filterClass) } } - checker(query, expected) } } @@ -109,43 +110,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) checkFilterPredicate('_1 === true, classOf[Eq[_]], true) + checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false) } } - test("filter pushdown - short") { - withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df => - checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1) - checkFilterPredicate( - Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - - checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4) - checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) - checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) - - checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1) - checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1) - checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4) - checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) - - checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) - checkFilterPredicate( - Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3) - checkFilterPredicate( - Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, - classOf[Operators.Or], - Seq(Row(1), Row(4))) - } - } - test("filter pushdown - integer") { withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -154,13 +130,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -171,6 +147,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -179,13 +156,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 >= 4, 
classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -196,6 +173,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -204,13 +182,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -221,6 +199,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 <=> 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) @@ -229,13 +208,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(1) <=> '_1, classOf[Eq[_]], 1) checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) - checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } } @@ -247,6 +226,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") + checkFilterPredicate('_1 <=> "1", classOf[Eq[_]], "1") checkFilterPredicate( '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) @@ -256,13 +236,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("1") <=> '_1, classOf[Eq[_]], 
"1") checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") - checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } @@ -274,6 +254,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkBinaryFilterPredicate( @@ -288,13 +269,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(1.b) <=> '_1, classOf[Eq[_]], 1.b) checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) checkBinaryFilterPredicate( '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } From 2f2686a73f5a2a53ca5b1023e0d7e0e6c9be5896 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 19 Aug 2015 17:35:41 -0700 Subject: [PATCH 013/802] [SPARK-9242] [SQL] Audit UDAF interface. A few minor changes: 1. Improved documentation 2. Rename apply(distinct....) to distinct. 3. Changed MutableAggregationBuffer from a trait to an abstract class. 4. Renamed returnDataType to dataType to be more consistent with other expressions. And unrelated to UDAFs: 1. Renamed file names in expressions to use suffix "Expressions" to be more consistent. 2. Moved regexp related expressions out to its own file. 3. Renamed StringComparison => StringPredicate. Author: Reynold Xin Closes #8321 from rxin/SPARK-9242. 
--- ...bitwise.scala => bitwiseExpressions.scala} | 0 ...als.scala => conditionalExpressions.scala} | 0 ...ctions.scala => datetimeExpressions.scala} | 0 ...nctions.scala => decimalExpressions.scala} | 0 ...nFunctions.scala => jsonExpressions.scala} | 0 .../{math.scala => mathExpressions.scala} | 0 ...lFunctions.scala => nullExpressions.scala} | 0 .../{random.scala => randomExpressions.scala} | 0 .../expressions/regexpExpressions.scala | 346 ++++++++++++++++++ ...erations.scala => stringExpressions.scala} | 332 +---------------- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 1 + .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../apache/spark/sql/expressions/udaf.scala | 44 ++- .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 2 +- .../spark/sql/hive/aggregate/MyDoubleSum.java | 2 +- 18 files changed, 386 insertions(+), 349 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{bitwise.scala => bitwiseExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{conditionals.scala => conditionalExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{datetimeFunctions.scala => datetimeExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{decimalFunctions.scala => decimalExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{jsonFunctions.scala => jsonExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{math.scala => mathExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{nullFunctions.scala => nullExpressions.scala} (100%) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{random.scala => randomExpressions.scala} (100%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{stringOperations.scala => stringExpressions.scala} (74%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala rename to 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala new file mode 100644 index 0000000000000..6dff28a7cde46 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.regex.{MatchResult, Pattern} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + + +trait StringRegexExpression extends ImplicitCastInputTypes { + self: BinaryExpression => + + def escape(v: String): String + def matches(regex: Pattern, str: String): Boolean + + override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + // try cache the pattern for Literal + private lazy val cache: Pattern = right match { + case x @ Literal(value: String, StringType) => compile(value) + case _ => null + } + + protected def compile(str: String): Pattern = if (str == null) { + null + } else { + // Let it raise exception if couldn't compile the regex string + Pattern.compile(escape(str)) + } + + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) + if(regex == null) { + null + } else { + matches(regex, input1.asInstanceOf[UTF8String].toString) + } + } +} + + +/** + * Simple RegEx pattern matching function + */ +case class Like(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression with CodegenFallback { + + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) + + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
+ val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } +} + + +case class RLike(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression with CodegenFallback { + + override def escape(v: String): String = v + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } +} + + +/** + * Splits str around pat (pattern is a regular expression). + */ +case class StringSplit(str: Expression, pattern: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = str + override def right: Expression = pattern + override def dataType: DataType = ArrayType(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + + override def nullSafeEval(string: Any, regex: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName + nullSafeCodeGen(ctx, ev, (str, pattern) => + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") + } + + override def prettyName: String = "split" +} + + +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
+ */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. + @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + UTF8String.fromString(result.toString) + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState("UTF8String", + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep.clone(); + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; + """ + }) + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
+ */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String].clone() + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 + } + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp.clone(); + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala similarity index 74% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ca044d3e95e38..3c23f2ecfb57c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -21,13 +21,9 @@ import java.text.DecimalFormat import java.util.Arrays import java.util.{Map => JMap, HashMap} import java.util.Locale -import java.util.regex.{MatchResult, Pattern} - -import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -124,143 +120,6 @@ case class ConcatWs(children: Seq[Expression]) } } - -trait StringRegexExpression extends ImplicitCastInputTypes { - self: BinaryExpression => - - def escape(v: String): String - def matches(regex: Pattern, str: String): Boolean - - override def 
dataType: DataType = BooleanType - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - // try cache the pattern for Literal - private lazy val cache: Pattern = right match { - case x @ Literal(value: String, StringType) => compile(value) - case _ => null - } - - protected def compile(str: String): Pattern = if (str == null) { - null - } else { - // Let it raise exception if couldn't compile the regex string - Pattern.compile(escape(str)) - } - - protected def pattern(str: String) = if (cache == null) compile(str) else cache - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val regex = pattern(input2.asInstanceOf[UTF8String].toString()) - if(regex == null) { - null - } else { - matches(regex, input1.asInstanceOf[UTF8String].toString()) - } - } -} - -/** - * Simple RegEx pattern matching function - */ -case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = StringUtils.escapeLikeRegex(v) - - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() - - override def toString: String = s"$left LIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); - """ - }) - } - } -} - - -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression with CodegenFallback { - - override def escape(v: String): String = v - override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) - override def toString: String = s"$left RLIKE $right" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val patternClass = classOf[Pattern].getName - val pattern = ctx.freshName("pattern") - - if (right.foldable) { - val rVal = right.eval() - if (rVal != null) { - val regexStr = - StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) - ctx.addMutableState(patternClass, pattern, - s"""$pattern = ${patternClass}.compile("$regexStr");""") - - // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
- val eval = left.gen(ctx) - s""" - ${eval.code} - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); - } - """ - } else { - s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - """ - } - } else { - nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile(rightStr); - ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); - """ - }) - } - } -} - - trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -305,7 +164,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ImplicitCastInputTypes { +trait StringPredicate extends Predicate with ImplicitCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -322,7 +181,7 @@ trait StringComparison extends ImplicitCastInputTypes { * A function that returns true if the string `left` contains the string `right`. */ case class Contains(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") @@ -333,7 +192,7 @@ case class Contains(left: Expression, right: Expression) * A function that returns true if the string `left` starts with the string `right`. */ case class StartsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") @@ -344,7 +203,7 @@ case class StartsWith(left: Expression, right: Expression) * A function that returns true if the string `left` ends with the string `right`. */ case class EndsWith(left: Expression, right: Expression) - extends BinaryExpression with Predicate with StringComparison { + extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") @@ -769,32 +628,6 @@ case class StringSpace(child: Expression) override def prettyName: String = "space" } -/** - * Splits str around pat (pattern is a regular expression). 
- */ -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - override def left: Expression = str - override def right: Expression = pattern - override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) - new GenericArrayData(strings.asInstanceOf[Array[Any]]) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => - // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") - } - - override def prettyName: String = "split" -} - object Substring { def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = { if (pos > bytes.length) { @@ -1048,163 +881,6 @@ case class Encode(value: Expression, charset: Expression) } } -/** - * Replace all substrings of str that match regexp with rep. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. - */ -case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - - // last regex in string, we will update the pattern iff regexp value changed. - @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - // last replacement string, we don't want to convert a UTF8String => java.langString every time. 
- @transient private var lastReplacement: String = _ - @transient private var lastReplacementInUTF8: UTF8String = _ - // result buffer write by Matcher - @transient private val result: StringBuffer = new StringBuffer - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone() - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) - - UTF8String.fromString(result.toString) - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = subject :: regexp :: rep :: Nil - override def prettyName: String = "regexp_replace" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - - val termLastReplacement = ctx.freshName("lastReplacement") - val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") - - val classNamePattern = classOf[Pattern].getCanonicalName - val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - - ctx.addMutableState("UTF8String", - termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, - termPattern, s"${termPattern} = null;") - ctx.addMutableState("String", - termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState("UTF8String", - termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") - - nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!$rep.equals(${termLastReplacementInUTF8})) { - // replacement string changed - ${termLastReplacementInUTF8} = $rep.clone(); - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); - ${ev.isNull} = false; - """ - }) - } -} - -/** - * Extract a specific(idx) group identified by a Java regex. - * - * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. - */ -case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends TernaryExpression with ImplicitCastInputTypes { - def this(s: Expression, r: Expression) = this(s, r, Literal(1)) - - // last regex in string, we will update the pattern iff regexp value changed. 
- @transient private var lastRegex: UTF8String = _ - // last regex pattern, we cache it for performance concern - @transient private var pattern: Pattern = _ - - override def nullSafeEval(s: Any, p: Any, r: Any): Any = { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String].clone() - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } else { - UTF8String.EMPTY_UTF8 - } - } - - override def dataType: DataType = StringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) - override def children: Seq[Expression] = subject :: regexp :: idx :: Nil - override def prettyName: String = "regexp_extract" - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val termLastRegex = ctx.freshName("lastRegex") - val termPattern = ctx.freshName("pattern") - val classNamePattern = classOf[Pattern].getCanonicalName - - ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") - ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - - nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { - s""" - if (!$regexp.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = $regexp.clone(); - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - java.util.regex.Matcher m = - ${termPattern}.matcher($subject.toString()); - if (m.find()) { - java.util.regex.MatchResult mr = m.toMatchResult(); - ${ev.primitive} = UTF8String.fromString(mr.group($idx)); - ${ev.isNull} = false; - } else { - ${ev.primitive} = UTF8String.EMPTY_UTF8; - ${ev.isNull} = false; - }""" - }) - } -} - /** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. 
If D is 0, the result has no decimal point or diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 42457d5318b48..854463dd11c74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -372,7 +372,7 @@ object NullPropagation extends Rule[LogicalPlan] { case _ => e } - case e: StringComparison => e.children match { + case e: StringPredicate => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType) case _ => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 426dc272471ae..99e3b13ce8c97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -673,7 +673,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes)) } - test("number format") { + test("format_number / FormatNumber") { checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000") checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 1f270560d7bc1..fc4d0938c533a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -56,6 +56,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { /** * Register a user-defined aggregate function (UDAF). + * * @param name the name of the UDAF. * @param udaf the UDAF needs to be registered. * @return the registered UDAF. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 7619f3ec9f0a7..d43d3dd9ffaae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -304,7 +304,7 @@ private[sql] case class ScalaUDAF( override def nullable: Boolean = true - override def dataType: DataType = udaf.returnDataType + override def dataType: DataType = udaf.dataType override def deterministic: Boolean = udaf.deterministic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 5180871585f25..258afadc76951 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} @@ -26,7 +25,7 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * The abstract class for implementing user-defined aggregate functions. + * The base class for implementing user-defined aggregate functions (UDAF). */ @Experimental abstract class UserDefinedAggregateFunction extends Serializable { @@ -67,22 +66,35 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. */ - def returnDataType: DataType + def dataType: DataType - /** Indicates if this function is deterministic. */ + /** + * Returns true iff this function is deterministic, i.e. given the same input, + * always return the same output. + */ def deterministic: Boolean /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer - * still store initial values. + * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer. + * + * The contract should be that applying the merge function on two initial buffers should just + * return the initial buffer itself, i.e. + * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`. */ def initialize(buffer: MutableAggregationBuffer): Unit - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + /** + * Updates the given aggregation buffer `buffer` with new input data from `input`. + * + * This is called once per input row. + */ def update(buffer: MutableAggregationBuffer, input: Row): Unit - /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + /** + * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. + * + * This is called when we merge two partially aggregated data together. + */ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit /** @@ -92,7 +104,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { def evaluate(buffer: Row): Any /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. + * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. 
*/ @scala.annotation.varargs def apply(exprs: Column*): Column = { @@ -105,16 +117,16 @@ abstract class UserDefinedAggregateFunction extends Serializable { } /** - * Creates a [[Column]] for this UDAF with given [[Column]]s as arguments. - * If `isDistinct` is true, this UDAF is working on distinct input values. + * Creates a [[Column]] for this UDAF using the distinct values of the given + * [[Column]]s as input arguments. */ @scala.annotation.varargs - def apply(isDistinct: Boolean, exprs: Column*): Column = { + def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression2( ScalaUDAF(exprs.map(_.expr), this), Complete, - isDistinct = isDistinct) + isDistinct = true) Column(aggregateExpression) } } @@ -122,9 +134,11 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * :: Experimental :: * A [[Row]] representing an mutable aggregation buffer. + * + * This is not meant to be extended outside of Spark. */ @Experimental -trait MutableAggregationBuffer extends Row { +abstract class MutableAggregationBuffer extends Row { /** Update the ith value of this buffer. */ def update(i: Int, value: Any): Unit diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 21b053f07a3ba..a30dfa554eabc 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -92,7 +92,7 @@ public void testUDAF() { DataFrame aggregatedDF = df.groupBy() .agg( - udaf.apply(true, col("value")), + udaf.distinct(col("value")), udaf.apply(col("value")), registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index a2247e3da1554..2961b803f14aa 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -65,7 +65,7 @@ public MyDoubleAvg() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index da29e24d267dd..c71882a6e7bed 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -60,7 +60,7 @@ public MyDoubleSum() { return _bufferSchema; } - @Override public DataType returnDataType() { + @Override public DataType dataType() { return _returnDataType; } From 1f29d502e7ecd6faa185d70dc714f9ea3922fb6d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 19 Aug 2015 18:36:01 -0700 Subject: [PATCH 014/802] [SPARK-9812] [STREAMING] Fix Python 3 compatibility issue in PySpark Streaming and some docs This PR includes the following fixes: 1. Use `range` instead of `xrange` in `queue_stream.py` to support Python 3. 2. Fix the issue that `utf8_decoder` will return `bytes` rather than `str` when receiving an empty `bytes` in Python 3. 3. Fix the commands in docs so that the user can copy them directly to the command line. 
The previous commands was broken in the middle of a path, so when copying to the command line, the path would be split to two parts by the extra spaces, which forces the user to fix it manually. Author: zsxwing Closes #8315 from zsxwing/SPARK-9812. --- .../src/main/python/streaming/direct_kafka_wordcount.py | 6 +++--- examples/src/main/python/streaming/flume_wordcount.py | 5 +++-- examples/src/main/python/streaming/kafka_wordcount.py | 5 +++-- examples/src/main/python/streaming/mqtt_wordcount.py | 5 +++-- examples/src/main/python/streaming/queue_stream.py | 4 ++-- python/pyspark/streaming/flume.py | 4 +++- python/pyspark/streaming/kafka.py | 4 +++- python/pyspark/streaming/kinesis.py | 4 +++- 8 files changed, 23 insertions(+), 14 deletions(-) diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py index 6ef188a220c51..ea20678b9acad 100644 --- a/examples/src/main/python/streaming/direct_kafka_wordcount.py +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -23,8 +23,8 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ examples/src/main/python/streaming/direct_kafka_wordcount.py \ localhost:9092 test` """ @@ -37,7 +37,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + print("Usage: direct_kafka_wordcount.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py index 091b64d8c4af4..d75bc6daac138 100644 --- a/examples/src/main/python/streaming/flume_wordcount.py +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -23,8 +23,9 @@ https://flume.apache.org/documentation.html and then run the example - `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ - spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + `$ bin/spark-submit --jars \ + external/flume-assembly/target/scala-*/spark-streaming-flume-assembly-*.jar \ + examples/src/main/python/streaming/flume_wordcount.py \ localhost 12345 """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index b178e7899b5e1..8d697f620f467 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -23,8 +23,9 @@ http://kafka.apache.org/documentation.html#quickstart and then run the example - `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ - spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + `$ bin/spark-submit --jars \ + external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/kafka_wordcount.py \ localhost:2181 test` """ from __future__ import print_function diff --git a/examples/src/main/python/streaming/mqtt_wordcount.py b/examples/src/main/python/streaming/mqtt_wordcount.py index 617ce5ea6775e..abf9c0e21d307 100644 --- a/examples/src/main/python/streaming/mqtt_wordcount.py +++ 
b/examples/src/main/python/streaming/mqtt_wordcount.py @@ -26,8 +26,9 @@ http://www.eclipse.org/paho/#getting-started and then run the example - `$ bin/spark-submit --jars external/mqtt-assembly/target/scala-*/\ - spark-streaming-mqtt-assembly-*.jar examples/src/main/python/streaming/mqtt_wordcount.py \ + `$ bin/spark-submit --jars \ + external/mqtt-assembly/target/scala-*/spark-streaming-mqtt-assembly-*.jar \ + examples/src/main/python/streaming/mqtt_wordcount.py \ tcp://localhost:1883 foo` """ diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py index dcd6a0fc6ff91..b3808907f74a6 100644 --- a/examples/src/main/python/streaming/queue_stream.py +++ b/examples/src/main/python/streaming/queue_stream.py @@ -36,8 +36,8 @@ # Create the queue through which RDDs can be pushed to # a QueueInputDStream rddQueue = [] - for i in xrange(5): - rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + for i in range(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)] # Create the QueueInputDStream and use it do some processing inputStream = ssc.queueStream(rddQueue) diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cbb573f226bbe..c0cdc50d8d423 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -31,7 +31,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class FlumeUtils(object): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index dc5b7fd878aef..8a814c64c0423 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -29,7 +29,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KafkaUtils(object): diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index bcfe2703fecf9..34be5880e1708 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -26,7 +26,9 @@ def utf8_decoder(s): """ Decode the unicode as UTF-8 """ - return s and s.decode('utf-8') + if s is None: + return None + return s.decode('utf-8') class KinesisUtils(object): From affc8a887ede9fdc2ca6051833954cd10918c869 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 19 Aug 2015 19:43:09 -0700 Subject: [PATCH 015/802] [SPARK-10125] [STREAMING] Fix a potential deadlock in JobGenerator.stop Because `lazy val` uses `this` lock, if JobGenerator.stop and JobGenerator.doCheckpoint (JobGenerator.shouldCheckpoint has not yet been initialized) run at the same time, it may hang. 
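The hazard being fixed is a general property of how Scala 2 implements `lazy val`: the initializer runs inside a block synchronized on the enclosing instance. The following self-contained sketch (class and member names are invented and it has no Spark dependencies; it is an illustration of the failure shape, not code from this patch) reproduces the same kind of deadlock, with a `synchronized` stop method joining a worker thread while that worker is blocked on the same monitor trying to initialize a lazy val.

```scala
object LazyValDeadlockSketch {
  class Generator {
    // In Scala 2, initializing this lazy val synchronizes on the Generator instance.
    lazy val shouldCheckpoint: Boolean = true

    private val worker = new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(50)          // analogue of the event loop picking up an event
        if (shouldCheckpoint) ()  // needs the Generator's monitor to run the initializer
      }
    })

    def start(): Unit = worker.start()

    // Analogue of a synchronized stop(): it holds the Generator's monitor while
    // waiting for the worker, but the worker is waiting for that same monitor.
    def stop(): Unit = synchronized {
      worker.join()
    }
  }

  def main(args: Array[String]): Unit = {
    val g = new Generator
    g.start()
    g.stop()  // very likely hangs: the classic lazy-val plus synchronized deadlock
  }
}
```

Forcing the lazy member to initialize before any other thread can touch it, as the patch does with `checkpointWriter`, removes the second thread's need for the monitor and breaks the cycle.
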
Here are the stack traces for the deadlock: ```Java "pool-1-thread-1-ScalaTest-running-StreamingListenerSuite" #11 prio=5 os_prio=31 tid=0x00007fd35d094800 nid=0x5703 in Object.wait() [0x000000012ecaf000] java.lang.Thread.State: WAITING (on object monitor) at java.lang.Object.wait(Native Method) at java.lang.Thread.join(Thread.java:1245) - locked <0x00000007b5d8d7f8> (a org.apache.spark.util.EventLoop$$anon$1) at java.lang.Thread.join(Thread.java:1319) at org.apache.spark.util.EventLoop.stop(EventLoop.scala:81) at org.apache.spark.streaming.scheduler.JobGenerator.stop(JobGenerator.scala:155) - locked <0x00000007b5d8cea0> (a org.apache.spark.streaming.scheduler.JobGenerator) at org.apache.spark.streaming.scheduler.JobScheduler.stop(JobScheduler.scala:95) - locked <0x00000007b5d8ced8> (a org.apache.spark.streaming.scheduler.JobScheduler) at org.apache.spark.streaming.StreamingContext.stop(StreamingContext.scala:687) "JobGenerator" #67 daemon prio=5 os_prio=31 tid=0x00007fd35c3b9800 nid=0x9f03 waiting for monitor entry [0x0000000139e4a000] java.lang.Thread.State: BLOCKED (on object monitor) at org.apache.spark.streaming.scheduler.JobGenerator.shouldCheckpoint$lzycompute(JobGenerator.scala:63) - waiting to lock <0x00000007b5d8cea0> (a org.apache.spark.streaming.scheduler.JobGenerator) at org.apache.spark.streaming.scheduler.JobGenerator.shouldCheckpoint(JobGenerator.scala:63) at org.apache.spark.streaming.scheduler.JobGenerator.doCheckpoint(JobGenerator.scala:290) at org.apache.spark.streaming.scheduler.JobGenerator.org$apache$spark$streaming$scheduler$JobGenerator$$processEvent(JobGenerator.scala:182) at org.apache.spark.streaming.scheduler.JobGenerator$$anon$1.onReceive(JobGenerator.scala:83) at org.apache.spark.streaming.scheduler.JobGenerator$$anon$1.onReceive(JobGenerator.scala:82) at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48) ``` I can use this patch to produce this deadlock: https://github.com/zsxwing/spark/commit/8a88f28d1331003a65fabef48ae3d22a7c21f05f And a timeout build in Jenkins due to this deadlock: https://amplab.cs.berkeley.edu/jenkins/job/NewSparkPullRequestBuilder/1654/ This PR initializes `checkpointWriter` before `eventLoop` uses it to avoid this deadlock. Author: zsxwing Closes #8326 from zsxwing/SPARK-10125. --- .../org/apache/spark/streaming/scheduler/JobGenerator.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 9f2117ada61c0..2de035d166e7b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -79,6 +79,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { def start(): Unit = synchronized { if (eventLoop != null) return // generator has already been started + // Call checkpointWriter here to initialize it before eventLoop uses it to avoid a deadlock. + // See SPARK-10125 + checkpointWriter + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) From 73431d8afb41b93888d2642a1ce2d011f03fb740 Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Wed, 19 Aug 2015 19:43:26 -0700 Subject: [PATCH 016/802] [SPARK-10124] [MESOS] Fix removing queued driver in mesos cluster mode. 
Currently the spark applications can be queued to the Mesos cluster dispatcher, but when multiple jobs are in queue we don't handle removing jobs from the buffer correctly while iterating and causes null pointer exception. This patch copies the buffer before iterating them, so exceptions aren't thrown when the jobs are removed. Author: Timothy Chen Closes #8322 from tnachen/fix_cluster_mode. --- .../cluster/mesos/MesosClusterScheduler.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 64ec2b8e3db15..1206f184fbc82 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -507,14 +507,16 @@ private[spark] class MesosClusterScheduler( val driversToRetry = pendingRetryDrivers.filter { d => d.retryState.get.nextRetry.before(currentTime) } + scheduleTasks( - driversToRetry, + copyBuffer(driversToRetry), removeFromPendingRetryDrivers, currentOffers, tasks) + // Then we walk through the queued drivers and try to schedule them. scheduleTasks( - queuedDrivers, + copyBuffer(queuedDrivers), removeFromQueuedDrivers, currentOffers, tasks) @@ -527,13 +529,14 @@ private[spark] class MesosClusterScheduler( .foreach(o => driver.declineOffer(o.getId)) } + private def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + def getSchedulerState(): MesosClusterSchedulerState = { - def copyBuffer( - buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { - val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) - buffer.copyToBuffer(newBuffer) - newBuffer - } stateLock.synchronized { new MesosClusterSchedulerState( frameworkId, From b762f9920f7587d3c08493c49dd2fede62110b88 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 19 Aug 2015 21:15:58 -0700 Subject: [PATCH 017/802] [SPARK-10128] [STREAMING] Used correct classloader to deserialize WAL data Recovering Kinesis sequence numbers from WAL leads to classnotfoundexception because the ObjectInputStream does not use the correct classloader and the SequenceNumberRanges class (in streaming-kinesis-asl package) cannot be found (added through spark-submit) while deserializing. The solution is to use `Thread.currentThread().getContextClassLoader` while deserializing. 
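For context on why the context classloader matters here: a plain `ObjectInputStream` resolves class names through the loader of the calling code rather than the thread's context classloader, so classes shipped at submit time with `--jars` (such as the Kinesis ones in this case) may not be visible to it. The sketch below shows the general loader-aware pattern the fix relies on; `LoaderAwareSerde.deserializeWithLoader` is a hypothetical helper with the same shape as the call being changed, not the body of Spark's `Utils.deserialize`.

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

object LoaderAwareSerde {
  // Deserialize `bytes`, resolving classes against an explicitly supplied loader.
  def deserializeWithLoader[T](bytes: Array[Byte], loader: ClassLoader): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
      override def resolveClass(desc: ObjectStreamClass): Class[_] =
        try {
          // Look the class up through the supplied loader first.
          Class.forName(desc.getName, false, loader)
        } catch {
          // Fall back to default resolution (covers primitives and bootstrap classes).
          case _: ClassNotFoundException => super.resolveClass(desc)
        }
    }
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}

// A call site analogous to the fix passes the context classloader explicitly:
//   LoaderAwareSerde.deserializeWithLoader[AnyRef](
//     byteBuffer.array, Thread.currentThread().getContextClassLoader)
```
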
Author: Tathagata Das Closes #8328 from tdas/SPARK-10128 and squashes the following commits: f19b1c2 [Tathagata Das] Used correct classloader to deserialize WAL data --- .../spark/streaming/scheduler/ReceivedBlockTracker.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 7720259a5d794..53b96d51c9180 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -199,7 +199,8 @@ private[streaming] class ReceivedBlockTracker( import scala.collection.JavaConversions._ writeAheadLog.readAll().foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) - Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { + Utils.deserialize[ReceivedBlockTrackerLogEvent]( + byteBuffer.array, Thread.currentThread().getContextClassLoader) match { case BlockAdditionEvent(receivedBlockInfo) => insertAddedBlock(receivedBlockInfo) case BatchAllocationEvent(time, allocatedBlocks) => From 43e0135421b2262cbb0e06aae53523f663b4f959 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 20 Aug 2015 15:30:31 +0800 Subject: [PATCH 018/802] [SPARK-10092] [SQL] Multi-DB support follow up. https://issues.apache.org/jira/browse/SPARK-10092 This pr is a follow-up one for Multi-DB support. It has the following changes: * `HiveContext.refreshTable` now accepts `dbName.tableName`. * `HiveContext.analyze` now accepts `dbName.tableName`. * `CreateTableUsing`, `CreateTableUsingAsSelect`, `CreateTempTableUsing`, `CreateTempTableUsingAsSelect`, `CreateMetastoreDataSource`, and `CreateMetastoreDataSourceAsSelect` all take `TableIdentifier` instead of the string representation of table name. * When you call `saveAsTable` with a specified database, the data will be saved to the correct location. * Explicitly do not allow users to create a temporary with a specified database name (users cannot do it before). * When we save table to metastore, we also check if db name and table name can be accepted by hive (using `MetaStoreUtils.validateName`). Author: Yin Huai Closes #8324 from yhuai/saveAsTableDB. 
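A short usage sketch may help make the list of behaviour changes above concrete. The database and table names below are invented, and the snippet assumes a Spark 1.5-era `HiveContext` as used throughout this patch series; it illustrates the described behaviour rather than reproducing code from the patch.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object MultiDbSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("multi-db-sketch"))
    val hiveContext = new HiveContext(sc)
    val df = hiveContext.range(0, 100).toDF("n")

    hiveContext.sql("CREATE DATABASE IF NOT EXISTS sales_db")

    // saveAsTable with a database-qualified name now writes under sales_db's location.
    df.write.format("parquet").saveAsTable("sales_db.daily_totals")

    // refreshTable and analyze accept the same dbName.tableName form after this change.
    hiveContext.refreshTable("sales_db.daily_totals")
    hiveContext.analyze("sales_db.daily_totals")

    // Temporary tables still cannot carry a database qualifier; a table name that
    // happens to contain dots must be quoted with backticks instead.
    df.registerTempTable("daily_totals_tmp")

    sc.stop()
  }
}
```
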
--- .../spark/sql/catalyst/TableIdentifier.scala | 4 +- .../spark/sql/catalyst/analysis/Catalog.scala | 63 +++++-- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../org/apache/spark/sql/SQLContext.scala | 15 +- .../spark/sql/execution/SparkStrategies.scala | 10 +- .../sql/execution/datasources/DDLParser.scala | 32 ++-- .../spark/sql/execution/datasources/ddl.scala | 22 +-- .../sql/execution/datasources/rules.scala | 8 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 35 ++++ .../apache/spark/sql/hive/HiveContext.scala | 14 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 ++- .../spark/sql/hive/HiveStrategies.scala | 12 +- .../spark/sql/hive/execution/commands.scala | 54 ++++-- .../spark/sql/hive/ListTablesSuite.scala | 6 - .../spark/sql/hive/MultiDatabaseSuite.scala | 158 +++++++++++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 35 ++++ 16 files changed, 398 insertions(+), 94 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala index aebcdeb9d070f..d701559bf2d9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -25,7 +25,9 @@ private[sql] case class TableIdentifier(table: String, database: Option[String] def toSeq: Seq[String] = database.toSeq :+ table - override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + override def toString: String = quotedString + + def quotedString: String = toSeq.map("`" + _ + "`").mkString(".") def unquotedString: String = toSeq.mkString(".") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 5766e6a2dd51a..503c4f4b20f38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -23,6 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} @@ -55,12 +56,15 @@ trait Catalog { def refreshTable(tableIdent: TableIdentifier): Unit + // TODO: Refactor it in the work of SPARK-10104 def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit + // TODO: Refactor it in the work of SPARK-10104 def unregisterTable(tableIdentifier: Seq[String]): Unit def unregisterAllTables(): Unit + // TODO: Refactor it in the work of SPARK-10104 protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = { if (conf.caseSensitiveAnalysis) { tableIdentifier @@ -69,6 +73,7 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDbTableName(tableIdent: Seq[String]): String = { val size = tableIdent.size if (size <= 2) { @@ -78,9 +83,22 @@ trait Catalog { } } + // TODO: Refactor it in the work of SPARK-10104 protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = { (tableIdent.lift(tableIdent.size - 2), tableIdent.last) } + + /** + * It is not allowed to specifiy database name for tables stored in [[SimpleCatalog]]. + * We use this method to check it. 
+ */ + protected def checkTableIdentifier(tableIdentifier: Seq[String]): Unit = { + if (tableIdentifier.length > 1) { + throw new AnalysisException("Specifying database name or other qualifiers are not allowed " + + "for temporary tables. If the table name has dots (.) in it, please quote the " + + "table name with backticks (`).") + } + } } class SimpleCatalog(val conf: CatalystConf) extends Catalog { @@ -89,11 +107,13 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.put(getDbTableName(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.remove(getDbTableName(tableIdent)) } @@ -103,6 +123,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { } override def tableExists(tableIdentifier: Seq[String]): Boolean = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) tables.containsKey(getDbTableName(tableIdent)) } @@ -110,6 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def lookupRelation( tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) val tableFullName = getDbTableName(tableIdent) val table = tables.get(tableFullName) @@ -149,7 +171,13 @@ trait OverrideCatalog extends Catalog { abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) - overrides.get(getDBTable(tableIdent)) match { + // A temporary tables only has a single part in the tableIdentifier. + val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } + overriddenTable match { case Some(_) => true case None => super.tableExists(tableIdentifier) } @@ -159,7 +187,12 @@ trait OverrideCatalog extends Catalog { tableIdentifier: Seq[String], alias: Option[String] = None): LogicalPlan = { val tableIdent = processTableIdentifier(tableIdentifier) - val overriddenTable = overrides.get(getDBTable(tableIdent)) + // A temporary tables only has a single part in the tableIdentifier. + val overriddenTable = if (tableIdentifier.length > 1) { + None: Option[LogicalPlan] + } else { + overrides.get(getDBTable(tableIdent)) + } val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are @@ -171,20 +204,8 @@ trait OverrideCatalog extends Catalog { } abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { - val dbName = if (conf.caseSensitiveAnalysis) { - databaseName - } else { - if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None - } - - val temporaryTables = overrides.filter { - // If a temporary table does not have an associated database, we should return its name. - case ((None, _), _) => true - // If a temporary table does have an associated database, we should return it if the database - // matches the given database name. - case ((db: Some[String], _), _) if db == dbName => true - case _ => false - }.map { + // We always return all temporary tables. 
+ val temporaryTables = overrides.map { case ((_, tableName), _) => (tableName, true) }.toSeq @@ -194,13 +215,19 @@ trait OverrideCatalog extends Catalog { override def registerTable( tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + checkTableIdentifier(tableIdentifier) val tableIdent = processTableIdentifier(tableIdentifier) overrides.put(getDBTable(tableIdent), plan) } override def unregisterTable(tableIdentifier: Seq[String]): Unit = { - val tableIdent = processTableIdentifier(tableIdentifier) - overrides.remove(getDBTable(tableIdent)) + // A temporary tables only has a single part in the tableIdentifier. + // If tableIdentifier has more than one parts, it is not a temporary table + // and we do not need to do anything at here. + if (tableIdentifier.length == 1) { + val tableIdent = processTableIdentifier(tableIdentifier) + overrides.remove(getDBTable(tableIdent)) + } } override def unregisterAllTables(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index f0bf1be506411..ce8744b53175b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -218,7 +218,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { case _ => val cmd = CreateTableUsingAsSelect( - tableIdent.unquotedString, + tableIdent, source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 58fe75b59f418..126c9c6f839c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -584,9 +584,10 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = None, source, temporary = false, @@ -594,7 +595,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -629,9 +630,10 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) val cmd = CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema = Some(schema), source, temporary = false, @@ -639,7 +641,7 @@ class SQLContext(@transient val sparkContext: SparkContext) allowExisting = false, managedIfNoPath = false) executePlan(cmd).toRdd - table(tableName) + table(tableIdent) } /** @@ -724,7 +726,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + table(new SqlParser().parseTableIdentifier(tableName)) + } + + private def table(tableIdent: TableIdentifier): DataFrame = { DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1fc870d44b578..4df53687a0731 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -395,22 +395,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) => + case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) => ExecutedCommand( CreateTempTableUsing( - tableName, userSpecifiedSchema, provider, opts)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts)) :: Nil case c: CreateTableUsing if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false when creating a temporary table.") - case CreateTableUsingAsSelect(tableName, provider, true, partitionsCols, mode, opts, query) + case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) if partitionsCols.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableName, provider, true, _, mode, opts, query) => + case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => val cmd = CreateTempTableUsingAsSelect( - tableName, provider, Array.empty[String], mode, opts, query) + tableIdent, provider, Array.empty[String], mode, opts, query) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index 6c462fa30461b..f7a88b98c0b48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -80,9 +80,9 @@ class DDLParser(parseQuery: String => LogicalPlan) */ protected lazy val createTable: Parser[LogicalPlan] = { // TODO: Support database.table. - (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~ + (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ tableIdentifier ~ tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ { - case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query => + case temp ~ allowExisting ~ tableIdent ~ columns ~ provider ~ opts ~ query => if (temp.isDefined && allowExisting.isDefined) { throw new DDLException( "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") @@ -104,7 +104,7 @@ class DDLParser(parseQuery: String => LogicalPlan) } val queryPlan = parseQuery(query.get) - CreateTableUsingAsSelect(tableName, + CreateTableUsingAsSelect(tableIdent, provider, temp.isDefined, Array.empty[String], @@ -114,7 +114,7 @@ class DDLParser(parseQuery: String => LogicalPlan) } else { val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields))) CreateTableUsing( - tableName, + tableIdent, userSpecifiedSchema, provider, temp.isDefined, @@ -125,6 +125,12 @@ class DDLParser(parseQuery: String => LogicalPlan) } } + // This is the same as tableIdentifier in SqlParser. 
+ protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? ~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* @@ -132,21 +138,15 @@ class DDLParser(parseQuery: String => LogicalPlan) * This will display all columns of table `avroTable` includes column_name,column_type,comment */ protected lazy val describeTable: Parser[LogicalPlan] = - (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => - val tblIdentifier = db match { - case Some(dbName) => - Seq(dbName, tbl) - case None => - Seq(tbl) - } - DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined) + (DESCRIBE ~> opt(EXTENDED)) ~ tableIdentifier ^^ { + case e ~ tableIdent => + DescribeCommand(UnresolvedRelation(tableIdent.toSeq, None), e.isDefined) } protected lazy val refreshTable: Parser[LogicalPlan] = - REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { - case maybeDatabaseName ~ tableName => - RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) + REFRESH ~> TABLE ~> tableIdentifier ^^ { + case tableIndet => + RefreshTable(tableIndet) } protected lazy val options: Parser[Map[String, String]] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index ecd304c30cdee..31d6b75e13477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -53,7 +53,7 @@ case class DescribeCommand( * If it is false, an exception will be thrown */ case class CreateTableUsing( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, temporary: Boolean, @@ -71,8 +71,9 @@ case class CreateTableUsing( * can analyze the logical plan that will be used to populate the table. * So, [[PreWriteCheck]] can detect cases that are not allowed. */ +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). case class CreateTableUsingAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, temporary: Boolean, partitionColumns: Array[String], @@ -80,12 +81,10 @@ case class CreateTableUsingAsSelect( options: Map[String, String], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = Seq.empty[Attribute] - // TODO: Override resolved after we support databaseName. 
- // override lazy val resolved = databaseName != None && childrenResolved } case class CreateTempTableUsing( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String]) extends RunnableCommand { @@ -93,14 +92,16 @@ case class CreateTempTableUsing( def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) + Seq.empty[Row] } } case class CreateTempTableUsingAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -110,8 +111,9 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) - sqlContext.registerDataFrameAsTable( - DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) + sqlContext.catalog.registerTable( + tableIdent.toSeq, + DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 9d3d35692ffcc..16c9138419fa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -140,12 +140,12 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableName, _, _, partitionColumns, mode, _, query) => + case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(Seq(tableName))) { + if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent.toSeq)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match { + EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). 
case l @ LogicalRelation(dest: BaseRelation) => @@ -155,7 +155,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableName that is also being read from.") + s"Cannot overwrite table $tableIdent that is also being read from.") } else { // OK } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 141468ca00d67..da50aec17c89e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1644,4 +1644,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select count(num) from 1one"), Row(10)) } } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 17762649fd70d..17cc83087fb1d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -189,6 +189,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // We instantiate a HiveConf here to read in the hive-site.xml file and then pass the options // into the isolated client loader val metadataConf = new HiveConf() + + val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("defalt warehouse location is " + defaltWarehouseLocation) + // `configure` goes second to override other settings. 
val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure @@ -288,12 +292,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(catalog.client.currentDatabase, tableName) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + catalog.invalidateTable(tableIdent) } /** @@ -307,7 +312,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { case relation: MetastoreRelation => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6770462bb0ad3..bbe8c1911bf86 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -174,10 +174,13 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) + invalidateTable(tableIdent) } - def invalidateTable(databaseName: String, tableName: String): Unit = { + def invalidateTable(tableIdent: TableIdentifier): Unit = { + val databaseName = tableIdent.database.getOrElse(client.currentDatabase) + val tableName = tableIdent.table + cachedDataSourceTables.invalidate(QualifiedTableName(databaseName, tableName).toLowerCase) } @@ -187,6 +190,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * Creates a data source table (a table created with USING clause) in Hive's metastore. * Returns true when the table has been created. Otherwise, false. */ + // TODO: Remove this in SPARK-10104. 
def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], @@ -203,7 +207,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive isExternal) } - private def createDataSourceTable( + def createDataSourceTable( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: Array[String], @@ -371,10 +375,16 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { + hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + } + + def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { // Code based on: hiveWarehouse.getTablePath(currentDatabase, tableName) + val database = tableIdent.database.getOrElse(client.currentDatabase) + new Path( - new Path(client.getDatabase(client.currentDatabase).location), - tableName.toLowerCase).toString + new Path(client.getDatabase(database).location), + tableIdent.table.toLowerCase).toString } def tableExists(tableIdentifier: Seq[String]): Boolean = { @@ -635,7 +645,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - desc.name, + TableIdentifier(desc.name), hive.conf.defaultDataSourceName, temporary = false, Array.empty[String], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index cd6cd322c94ed..d38ad9127327d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -83,14 +83,16 @@ private[hive] trait HiveStrategies { object HiveDDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case CreateTableUsing( - tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => - ExecutedCommand( + tableIdent, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) => + val cmd = CreateMetastoreDataSource( - tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil + tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) + ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect(tableName, provider, false, partitionCols, mode, opts, query) => + case CreateTableUsingAsSelect( + tableIdent, provider, false, partitionCols, mode, opts, query) => val cmd = - CreateMetastoreDataSourceAsSelect(tableName, provider, partitionCols, mode, opts, query) + CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 05a78930afe3d..d1699dd536817 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{TableIdentifier, SqlParser} import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.expressions.Attribute import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -120,9 +122,10 @@ case class AddFile(path: String) extends RunnableCommand { } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSource( - tableName: String, + tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], @@ -130,9 +133,24 @@ case class CreateMetastoreDataSource( managedIfNoPath: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] - if (hiveContext.catalog.tableExists(tableName :: Nil)) { + if (hiveContext.catalog.tableExists(tableIdent.toSeq)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -144,13 +162,13 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, userSpecifiedSchema, Array.empty[String], provider, @@ -161,9 +179,10 @@ case class CreateMetastoreDataSource( } } +// TODO: Use TableIdentifier instead of String for tableName (SPARK-10104). private[hive] case class CreateMetastoreDataSourceAsSelect( - tableName: String, + tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], mode: SaveMode, @@ -171,19 +190,34 @@ case class CreateMetastoreDataSourceAsSelect( query: LogicalPlan) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Since we are saving metadata to metastore, we need to check if metastore supports + // the table name and database name we have for this query. MetaStoreUtils.validateName + // is the method used by Hive to check if a table name or a database name is valid for + // the metastore. + if (!MetaStoreUtils.validateName(tableIdent.table)) { + throw new AnalysisException(s"Table name ${tableIdent.table} is not a valid name for " + + s"metastore. Metastore only accepts table name containing characters, numbers and _.") + } + if (tableIdent.database.isDefined && !MetaStoreUtils.validateName(tableIdent.database.get)) { + throw new AnalysisException(s"Database name ${tableIdent.database.get} is not a valid name " + + s"for metastore. 
Metastore only accepts database name containing " + + s"characters, numbers and _.") + } + + val tableName = tableIdent.unquotedString val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true val optionsWithPath = if (!options.contains("path")) { isExternal = false - options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName)) + options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } var existingSchema = None: Option[StructType] - if (sqlContext.catalog.tableExists(Seq(tableName))) { + if (sqlContext.catalog.tableExists(tableIdent.toSeq)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -200,7 +234,7 @@ case class CreateMetastoreDataSourceAsSelect( val resolved = ResolvedDataSource( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) - EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { + EliminateSubQueries(sqlContext.catalog.lookupRelation(tableIdent.toSeq)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = @@ -249,7 +283,7 @@ case class CreateMetastoreDataSourceAsSelect( // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). hiveContext.catalog.createDataSourceTable( - tableName, + tableIdent, Some(resolved.relation.schema), partitionColumns, provider, @@ -258,7 +292,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. - hiveContext.refreshTable(tableName) + hiveContext.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 1c15997ea8e6d..d3388a9429e41 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. 
catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan) - catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") @@ -42,7 +41,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { override def afterAll(): Unit = { catalog.unregisterTable(Seq("ListTablesSuiteTable")) - catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable")) sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable") sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable") sql("DROP DATABASE IF EXISTS ListTablesSuiteDB") @@ -55,7 +53,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hivelisttablessuitetable'"), Row("hivelisttablessuitetable", false)) @@ -69,9 +66,6 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( allTables.filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) - checkAnswer( - allTables.filter("tableName = 'indblisttablessuitetable'"), - Row("indblisttablessuitetable", true)) assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0) checkAnswer( allTables.filter("tableName = 'hiveindblisttablessuitetable'"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 417e8b07917cc..997c667ec0d1b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -19,14 +19,22 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, QueryTest, SQLContext, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils { - override val _sqlContext: SQLContext = TestHive + override val _sqlContext: HiveContext = TestHive private val sqlContext = _sqlContext private val df = sqlContext.range(10).coalesce(1) + private def checkTablePath(dbName: String, tableName: String): Unit = { + // val hiveContext = sqlContext.asInstanceOf[HiveContext] + val metastoreTable = sqlContext.catalog.client.getTable(dbName, tableName) + val expectedPath = sqlContext.catalog.client.getDatabase(dbName).location + "/" + tableName + + assert(metastoreTable.serdeProperties("path") === expectedPath) + } + test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { @@ -37,6 +45,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") } } @@ -45,6 +55,58 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df) + + checkTablePath(db, "t") + } + } + + 
test(s"createExternalTable() to non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + + sqlContext.createExternalTable("t", path, "parquet") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table("t"), df) + + sql( + s""" + |CREATE TABLE t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table("t1"), df) + } + } + } + } + + test(s"createExternalTable() to non-default database - without USE") { + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + sqlContext.createExternalTable(s"$db.t", path, "parquet") + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + + sql( + s""" + |CREATE TABLE $db.t1 + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + assert(sqlContext.tableNames(db).contains("t1")) + checkAnswer(sqlContext.table(s"$db.t1"), df) + } } } @@ -59,6 +121,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -68,6 +132,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") assert(sqlContext.tableNames(db).contains("t")) checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + + checkTablePath(db, "t") } } @@ -130,7 +196,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { } } - test("Refreshes a table in a non-default database") { + test("Refreshes a table in a non-default database - with USE") { import org.apache.spark.sql.functions.lit withTempDatabase { db => @@ -151,8 +217,94 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql("ALTER TABLE t ADD PARTITION (p=2)") + sqlContext.refreshTable("t") + checkAnswer( + sqlContext.table("t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) } } } } + + test("Refreshes a table in a non-default database - without USE") { + import org.apache.spark.sql.functions.lit + + withTempDatabase { db => + withTempPath { dir => + val path = dir.getCanonicalPath + + sql( + s"""CREATE EXTERNAL TABLE $db.t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") + sql(s"REFRESH TABLE $db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + + df.write.parquet(s"$path/p=2") + sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") + sqlContext.refreshTable(s"$db.t") + checkAnswer( + sqlContext.table(s"$db.t"), + df.withColumn("p", lit(1)).unionAll(df.withColumn("p", lit(2)))) + } + } + } + + test("invalid database name and table names") { + { + val message = intercept[AnalysisException] { + df.write.format("parquet").saveAsTable("`d:b`.`t:a`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + 
df.write.format("parquet").saveAsTable("`d:b`.`table`") + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`t:a` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + + { + val message = intercept[AnalysisException] { + sql( + s""" + |CREATE TABLE `d:b`.`table` (a int) + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("is not a valid name for metastore")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8b8f520776e70..55ecbd5b5f21d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1138,4 +1138,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { Row(CalendarInterval.fromString( "interval 4 minutes 59 seconds 889 milliseconds 987 microseconds"))) } + + test("specifying database name for a temporary table is not allowed") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = + sqlContext.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("num", "str") + df + .write + .format("parquet") + .save(path) + + val message = intercept[AnalysisException] { + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + }.getMessage + assert(message.contains("Specifying database name or other qualifiers are not allowed")) + + // If you use backticks to quote the name of a temporary table having dot in it. + sqlContext.sql( + s""" + |CREATE TEMPORARY TABLE `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) + """.stripMargin) + checkAnswer(sqlContext.table("`db.t`"), df) + } + } } From b4f4e91c395cb69ced61d9ff1492d1b814f96828 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 20 Aug 2015 07:53:27 -0700 Subject: [PATCH 019/802] [SPARK-10100] [SQL] Eliminate hash table lookup if there is no grouping key in aggregation. This improves performance by ~ 20 - 30% in one of my local test and should fix the performance regression from 1.4 to 1.5 on ss_max. Author: Reynold Xin Closes #8332 from rxin/SPARK-10100. --- .../aggregate/TungstenAggregate.scala | 2 +- .../TungstenAggregationIterator.scala | 30 +++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 99f51ba5b6935..ba379d358d206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -104,7 +104,7 @@ case class TungstenAggregate( } else { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. 
- Iterator[UnsafeRow]() + Iterator.empty } } else { aggregationIterator.start(parentIterator) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index af7e0fcedbe4e..26fdbc83ef50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -357,18 +357,30 @@ class TungstenAggregationIterator( // sort-based aggregation (by calling switchToSortBasedAggregation). private def processInputs(): Unit = { assert(inputIter != null, "attempted to process input when iterator was null") - while (!sortBased && inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) + if (groupingExpressions.isEmpty) { + // If there is no grouping expressions, we can just reuse the same buffer over and over again. + // Note that it would be better to eliminate the hash map entirely in the future. + val groupingKey = groupProjection.apply(null) val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { + while (inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 processRow(buffer, newInput) } + } else { + while (!sortBased && inputIter.hasNext) { + val newInput = inputIter.next() + numInputRows += 1 + val groupingKey = groupProjection.apply(newInput) + val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // buffer == null means that we could not allocate more memory. + // Now, we need to spill the map and switch to sort-based aggregation. + switchToSortBasedAggregation(groupingKey, newInput) + } else { + processRow(buffer, newInput) + } + } } } From 52c60537a274af5414f6b0340a4bd7488ef35280 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 20 Aug 2015 10:05:31 -0700 Subject: [PATCH 020/802] [MINOR] [SQL] Fix sphinx warnings in PySpark SQL Author: MechCoder Closes #8171 from MechCoder/sql_sphinx. --- python/pyspark/context.py | 8 ++++---- python/pyspark/sql/types.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index eb5b0bbbdac4b..1b2a52ad64114 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -302,10 +302,10 @@ def applicationId(self): """ A unique identifier for the Spark application. Its format depends on the scheduler implementation. - (i.e. - in case of local spark app something like 'local-1433865536131' - in case of YARN something like 'application_1433865536131_34483' - ) + + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + >>> sc.applicationId # doctest: +ELLIPSIS u'local-...' """ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c083bf89905bf..ed4e5b594bd61 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -467,9 +467,11 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ Construct a StructType by adding new elements to it to define the schema. 
The method accepts either: + a) A single parameter which is a StructField object. b) Between 2 and 4 parameters as (name, data_type, nullable (optional), - metadata(optional). The data_type parameter may be either a String or a DataType object + metadata(optional). The data_type parameter may be either a String or a + DataType object. >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) >>> struct2 = StructType([StructField("f1", StringType(), True),\ From 39e91fe2fd43044cc734d55625a3c03284b69f09 Mon Sep 17 00:00:00 2001 From: Alex Shkurenko Date: Thu, 20 Aug 2015 10:16:38 -0700 Subject: [PATCH 021/802] [SPARK-9982] [SPARKR] SparkR DataFrame fail to return data of Decimal type Author: Alex Shkurenko Closes #8239 from ashkurenko/master. --- core/src/main/scala/org/apache/spark/api/r/SerDe.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index d5b4260bf4529..3c89f24473744 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -181,6 +181,7 @@ private[spark] object SerDe { // Boolean -> logical // Float -> double // Double -> double + // Decimal -> double // Long -> double // Array[Byte] -> raw // Date -> Date @@ -219,6 +220,10 @@ private[spark] object SerDe { case "float" | "java.lang.Float" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "decimal" | "java.math.BigDecimal" => + writeType(dos, "double") + val javaDecimal = value.asInstanceOf[java.math.BigDecimal] + writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) case "double" | "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) From 85f9a61357994da5023b08b0a8a2eb09388ce7f8 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 20 Aug 2015 11:00:24 -0700 Subject: [PATCH 022/802] [SPARK-10136] [SQL] Fixes Parquet support for Avro array of primitive array I caught SPARK-10136 while adding more test cases to `ParquetAvroCompatibilitySuite`. Actual bug fix code lies in `CatalystRowConverter.scala`. Author: Cheng Lian Closes #8341 from liancheng/spark-10136/parquet-avro-nested-primitive-array. 
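For orientation, the hedged sketch below only illustrates the data shape this fix targets, a column typed array<array<int>>. The regression itself concerned files written through parquet-avro and is covered by the new ParquetAvroCompatibilitySuite cases; the output path and the `sqlContext` value here are assumptions.

// Illustration of the nested-array shape, not a reproduction of the parquet-avro bug.
import sqlContext.implicits._

val nested = sqlContext.sparkContext
  .parallelize(Seq(Tuple1(Seq(Seq(1, 2), Seq(3))), Tuple1(Seq(Seq(4, 5, 6)))))
  .toDF("int_arrays_column")

nested.write.mode("overwrite").parquet("/tmp/nested-int-arrays")
sqlContext.read.parquet("/tmp/nested-int-arrays").collect().foreach(println)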
--- .../parquet/CatalystReadSupport.scala | 1 - .../parquet/CatalystRowConverter.scala | 24 +- sql/core/src/test/avro/parquet-compat.avdl | 19 +- sql/core/src/test/avro/parquet-compat.avpr | 54 +- .../parquet/test/avro/AvroArrayOfArray.java | 142 +++ .../parquet/test/avro/AvroMapOfArray.java | 142 +++ .../test/avro/AvroNonNullableArrays.java | 196 +++++ .../test/avro/AvroOptionalPrimitives.java | 466 ++++++++++ .../parquet/test/avro/AvroPrimitives.java | 461 ++++++++++ .../parquet/test/avro/CompatibilityTest.java | 2 +- .../parquet/test/avro/ParquetAvroCompat.java | 821 +----------------- .../ParquetAvroCompatibilitySuite.scala | 227 +++-- .../parquet/ParquetCompatibilityTest.scala | 7 + 13 files changed, 1718 insertions(+), 844 deletions(-) create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java create mode 100644 sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index a4679bb2f6389..3f8353af6e2ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -61,7 +61,6 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with | |Parquet form: |$parquetRequestedSchema - | |Catalyst form: |$catalystRequestedSchema """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 18c5b500209e6..d2c2db51769ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -25,11 +25,12 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.OriginalType.{LIST, INT_32, UTF8} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -145,7 +146,16 @@ private[parquet] class CatalystRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends CatalystGroupConverter(updater) { + extends CatalystGroupConverter(updater) with Logging { + + logDebug( + 
s"""Building row converter for the following schema: + | + |Parquet form: + |$parquetType + |Catalyst form: + |${catalystType.prettyJson} + """.stripMargin) /** * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates @@ -464,9 +474,15 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = converter - override def end(): Unit = currentArray += currentElement + override def end(): Unit = { + converter.updater.end() + currentArray += currentElement + } - override def start(): Unit = currentElement = null + override def start(): Unit = { + converter.updater.start() + currentElement = null + } } } diff --git a/sql/core/src/test/avro/parquet-compat.avdl b/sql/core/src/test/avro/parquet-compat.avdl index 8070d0a9170a3..c5eb5b5164cf4 100644 --- a/sql/core/src/test/avro/parquet-compat.avdl +++ b/sql/core/src/test/avro/parquet-compat.avdl @@ -34,7 +34,7 @@ protocol CompatibilityTest { string nested_string_column; } - record ParquetAvroCompat { + record AvroPrimitives { boolean bool_column; int int_column; long long_column; @@ -42,7 +42,9 @@ protocol CompatibilityTest { double double_column; bytes binary_column; string string_column; + } + record AvroOptionalPrimitives { union { null, boolean } maybe_bool_column; union { null, int } maybe_int_column; union { null, long } maybe_long_column; @@ -50,7 +52,22 @@ protocol CompatibilityTest { union { null, double } maybe_double_column; union { null, bytes } maybe_binary_column; union { null, string } maybe_string_column; + } + + record AvroNonNullableArrays { + array strings_column; + union { null, array } maybe_ints_column; + } + record AvroArrayOfArray { + array> int_arrays_column; + } + + record AvroMapOfArray { + map> string_to_ints_column; + } + + record ParquetAvroCompat { array strings_column; map string_to_int_column; map> complex_column; diff --git a/sql/core/src/test/avro/parquet-compat.avpr b/sql/core/src/test/avro/parquet-compat.avpr index 060391765034b..9ad315b74fb41 100644 --- a/sql/core/src/test/avro/parquet-compat.avpr +++ b/sql/core/src/test/avro/parquet-compat.avpr @@ -27,7 +27,7 @@ } ] }, { "type" : "record", - "name" : "ParquetAvroCompat", + "name" : "AvroPrimitives", "fields" : [ { "name" : "bool_column", "type" : "boolean" @@ -49,7 +49,11 @@ }, { "name" : "string_column", "type" : "string" - }, { + } ] + }, { + "type" : "record", + "name" : "AvroOptionalPrimitives", + "fields" : [ { "name" : "maybe_bool_column", "type" : [ "null", "boolean" ] }, { @@ -70,7 +74,53 @@ }, { "name" : "maybe_string_column", "type" : [ "null", "string" ] + } ] + }, { + "type" : "record", + "name" : "AvroNonNullableArrays", + "fields" : [ { + "name" : "strings_column", + "type" : { + "type" : "array", + "items" : "string" + } }, { + "name" : "maybe_ints_column", + "type" : [ "null", { + "type" : "array", + "items" : "int" + } ] + } ] + }, { + "type" : "record", + "name" : "AvroArrayOfArray", + "fields" : [ { + "name" : "int_arrays_column", + "type" : { + "type" : "array", + "items" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "AvroMapOfArray", + "fields" : [ { + "name" : "string_to_ints_column", + "type" : { + "type" : "map", + "values" : { + "type" : "array", + "items" : "int" + } + } + } ] + }, { + "type" : "record", + "name" : "ParquetAvroCompat", + "fields" : [ { "name" : "strings_column", "type" : { "type" : "array", diff --git 
a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java new file mode 100644 index 0000000000000..ee327827903e5 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroArrayOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroArrayOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List> int_arrays_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroArrayOfArray() {} + + /** + * All-args constructor. + */ + public AvroArrayOfArray(java.util.List> int_arrays_column) { + this.int_arrays_column = int_arrays_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return int_arrays_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: int_arrays_column = (java.util.List>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'int_arrays_column' field. + */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** + * Sets the value of the 'int_arrays_column' field. + * @param value the value to set. 
+ */ + public void setIntArraysColumn(java.util.List> value) { + this.int_arrays_column = value; + } + + /** Creates a new AvroArrayOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** Creates a new AvroArrayOfArray RecordBuilder by copying an existing AvroArrayOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroArrayOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List> int_arrays_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroArrayOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.int_arrays_column)) { + this.int_arrays_column = data().deepCopy(fields()[0].schema(), other.int_arrays_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'int_arrays_column' field */ + public java.util.List> getIntArraysColumn() { + return int_arrays_column; + } + + /** Sets the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder setIntArraysColumn(java.util.List> value) { + validate(fields()[0], value); + this.int_arrays_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'int_arrays_column' field has been set */ + public boolean hasIntArraysColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'int_arrays_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroArrayOfArray.Builder clearIntArraysColumn() { + int_arrays_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroArrayOfArray build() { + try { + AvroArrayOfArray record = new AvroArrayOfArray(); + record.int_arrays_column = fieldSetFlags()[0] ? 
this.int_arrays_column : (java.util.List>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java new file mode 100644 index 0000000000000..727f6a7bf733e --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroMapOfArray.java @@ -0,0 +1,142 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroMapOfArray extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.Map> string_to_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroMapOfArray() {} + + /** + * All-args constructor. + */ + public AvroMapOfArray(java.util.Map> string_to_ints_column) { + this.string_to_ints_column = string_to_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return string_to_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: string_to_ints_column = (java.util.Map>)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'string_to_ints_column' field. + */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** + * Sets the value of the 'string_to_ints_column' field. + * @param value the value to set. 
+ */ + public void setStringToIntsColumn(java.util.Map> value) { + this.string_to_ints_column = value; + } + + /** Creates a new AvroMapOfArray RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** Creates a new AvroMapOfArray RecordBuilder by copying an existing AvroMapOfArray instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder(other); + } + + /** + * RecordBuilder for AvroMapOfArray instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.Map> string_to_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder other) { + super(other); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Creates a Builder by copying an existing AvroMapOfArray instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.SCHEMA$); + if (isValidValue(fields()[0], other.string_to_ints_column)) { + this.string_to_ints_column = data().deepCopy(fields()[0].schema(), other.string_to_ints_column); + fieldSetFlags()[0] = true; + } + } + + /** Gets the value of the 'string_to_ints_column' field */ + public java.util.Map> getStringToIntsColumn() { + return string_to_ints_column; + } + + /** Sets the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder setStringToIntsColumn(java.util.Map> value) { + validate(fields()[0], value); + this.string_to_ints_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'string_to_ints_column' field has been set */ + public boolean hasStringToIntsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'string_to_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroMapOfArray.Builder clearStringToIntsColumn() { + string_to_ints_column = null; + fieldSetFlags()[0] = false; + return this; + } + + @Override + public AvroMapOfArray build() { + try { + AvroMapOfArray record = new AvroMapOfArray(); + record.string_to_ints_column = fieldSetFlags()[0] ? 
this.string_to_ints_column : (java.util.Map>) defaultValue(fields()[0]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java new file mode 100644 index 0000000000000..934793f42f9c9 --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroNonNullableArrays.java @@ -0,0 +1,196 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroNonNullableArrays extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.util.List strings_column; + @Deprecated public java.util.List maybe_ints_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroNonNullableArrays() {} + + /** + * All-args constructor. + */ + public AvroNonNullableArrays(java.util.List strings_column, java.util.List maybe_ints_column) { + this.strings_column = strings_column; + this.maybe_ints_column = maybe_ints_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return strings_column; + case 1: return maybe_ints_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: strings_column = (java.util.List)value$; break; + case 1: maybe_ints_column = (java.util.List)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'strings_column' field. + */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** + * Sets the value of the 'strings_column' field. + * @param value the value to set. + */ + public void setStringsColumn(java.util.List value) { + this.strings_column = value; + } + + /** + * Gets the value of the 'maybe_ints_column' field. + */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** + * Sets the value of the 'maybe_ints_column' field. + * @param value the value to set. 
+ */ + public void setMaybeIntsColumn(java.util.List value) { + this.maybe_ints_column = value; + } + + /** Creates a new AvroNonNullableArrays RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** Creates a new AvroNonNullableArrays RecordBuilder by copying an existing AvroNonNullableArrays instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder(other); + } + + /** + * RecordBuilder for AvroNonNullableArrays instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.util.List strings_column; + private java.util.List maybe_ints_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder other) { + super(other); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Creates a Builder by copying an existing AvroNonNullableArrays instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.SCHEMA$); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_ints_column)) { + this.maybe_ints_column = data().deepCopy(fields()[1].schema(), other.maybe_ints_column); + fieldSetFlags()[1] = true; + } + } + + /** Gets the value of the 'strings_column' field */ + public java.util.List getStringsColumn() { + return strings_column; + } + + /** Sets the value of the 'strings_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setStringsColumn(java.util.List value) { + validate(fields()[0], value); + this.strings_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'strings_column' field has been set */ + public boolean hasStringsColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'strings_column' 
field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearStringsColumn() { + strings_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_ints_column' field */ + public java.util.List getMaybeIntsColumn() { + return maybe_ints_column; + } + + /** Sets the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder setMaybeIntsColumn(java.util.List value) { + validate(fields()[1], value); + this.maybe_ints_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_ints_column' field has been set */ + public boolean hasMaybeIntsColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_ints_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroNonNullableArrays.Builder clearMaybeIntsColumn() { + maybe_ints_column = null; + fieldSetFlags()[1] = false; + return this; + } + + @Override + public AvroNonNullableArrays build() { + try { + AvroNonNullableArrays record = new AvroNonNullableArrays(); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.maybe_ints_column = fieldSetFlags()[1] ? this.maybe_ints_column : (java.util.List) defaultValue(fields()[1]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java new file mode 100644 index 0000000000000..e4d1ead8dd15f --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroOptionalPrimitives.java @@ -0,0 +1,466 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroOptionalPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public java.lang.Boolean maybe_bool_column; + @Deprecated public java.lang.Integer maybe_int_column; + @Deprecated public java.lang.Long maybe_long_column; + @Deprecated public java.lang.Float maybe_float_column; + @Deprecated public java.lang.Double maybe_double_column; + @Deprecated public java.nio.ByteBuffer maybe_binary_column; + @Deprecated public java.lang.String 
maybe_string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroOptionalPrimitives() {} + + /** + * All-args constructor. + */ + public AvroOptionalPrimitives(java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column) { + this.maybe_bool_column = maybe_bool_column; + this.maybe_int_column = maybe_int_column; + this.maybe_long_column = maybe_long_column; + this.maybe_float_column = maybe_float_column; + this.maybe_double_column = maybe_double_column; + this.maybe_binary_column = maybe_binary_column; + this.maybe_string_column = maybe_string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return maybe_bool_column; + case 1: return maybe_int_column; + case 2: return maybe_long_column; + case 3: return maybe_float_column; + case 4: return maybe_double_column; + case 5: return maybe_binary_column; + case 6: return maybe_string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. + @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: maybe_bool_column = (java.lang.Boolean)value$; break; + case 1: maybe_int_column = (java.lang.Integer)value$; break; + case 2: maybe_long_column = (java.lang.Long)value$; break; + case 3: maybe_float_column = (java.lang.Float)value$; break; + case 4: maybe_double_column = (java.lang.Double)value$; break; + case 5: maybe_binary_column = (java.nio.ByteBuffer)value$; break; + case 6: maybe_string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'maybe_bool_column' field. + */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** + * Sets the value of the 'maybe_bool_column' field. + * @param value the value to set. + */ + public void setMaybeBoolColumn(java.lang.Boolean value) { + this.maybe_bool_column = value; + } + + /** + * Gets the value of the 'maybe_int_column' field. + */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** + * Sets the value of the 'maybe_int_column' field. + * @param value the value to set. + */ + public void setMaybeIntColumn(java.lang.Integer value) { + this.maybe_int_column = value; + } + + /** + * Gets the value of the 'maybe_long_column' field. + */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** + * Sets the value of the 'maybe_long_column' field. + * @param value the value to set. + */ + public void setMaybeLongColumn(java.lang.Long value) { + this.maybe_long_column = value; + } + + /** + * Gets the value of the 'maybe_float_column' field. + */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** + * Sets the value of the 'maybe_float_column' field. + * @param value the value to set. 
+ */ + public void setMaybeFloatColumn(java.lang.Float value) { + this.maybe_float_column = value; + } + + /** + * Gets the value of the 'maybe_double_column' field. + */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** + * Sets the value of the 'maybe_double_column' field. + * @param value the value to set. + */ + public void setMaybeDoubleColumn(java.lang.Double value) { + this.maybe_double_column = value; + } + + /** + * Gets the value of the 'maybe_binary_column' field. + */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** + * Sets the value of the 'maybe_binary_column' field. + * @param value the value to set. + */ + public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { + this.maybe_binary_column = value; + } + + /** + * Gets the value of the 'maybe_string_column' field. + */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** + * Sets the value of the 'maybe_string_column' field. + * @param value the value to set. + */ + public void setMaybeStringColumn(java.lang.String value) { + this.maybe_string_column = value; + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** Creates a new AvroOptionalPrimitives RecordBuilder by copying an existing AvroOptionalPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroOptionalPrimitives instances. 
+ */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private java.lang.Boolean maybe_bool_column; + private java.lang.Integer maybe_int_column; + private java.lang.Long maybe_long_column; + private java.lang.Float maybe_float_column; + private java.lang.Double maybe_double_column; + private java.nio.ByteBuffer maybe_binary_column; + private java.lang.String maybe_string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroOptionalPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.maybe_bool_column)) { + this.maybe_bool_column = data().deepCopy(fields()[0].schema(), other.maybe_bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.maybe_int_column)) { + this.maybe_int_column = data().deepCopy(fields()[1].schema(), other.maybe_int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.maybe_long_column)) { + this.maybe_long_column = data().deepCopy(fields()[2].schema(), other.maybe_long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.maybe_float_column)) { + this.maybe_float_column = data().deepCopy(fields()[3].schema(), other.maybe_float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.maybe_double_column)) { + this.maybe_double_column = data().deepCopy(fields()[4].schema(), other.maybe_double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.maybe_binary_column)) { + this.maybe_binary_column = data().deepCopy(fields()[5].schema(), other.maybe_binary_column); + fieldSetFlags()[5] = true; + } 
+ if (isValidValue(fields()[6], other.maybe_string_column)) { + this.maybe_string_column = data().deepCopy(fields()[6].schema(), other.maybe_string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'maybe_bool_column' field */ + public java.lang.Boolean getMaybeBoolColumn() { + return maybe_bool_column; + } + + /** Sets the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBoolColumn(java.lang.Boolean value) { + validate(fields()[0], value); + this.maybe_bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'maybe_bool_column' field has been set */ + public boolean hasMaybeBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'maybe_bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBoolColumn() { + maybe_bool_column = null; + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'maybe_int_column' field */ + public java.lang.Integer getMaybeIntColumn() { + return maybe_int_column; + } + + /** Sets the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeIntColumn(java.lang.Integer value) { + validate(fields()[1], value); + this.maybe_int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'maybe_int_column' field has been set */ + public boolean hasMaybeIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'maybe_int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeIntColumn() { + maybe_int_column = null; + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'maybe_long_column' field */ + public java.lang.Long getMaybeLongColumn() { + return maybe_long_column; + } + + /** Sets the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeLongColumn(java.lang.Long value) { + validate(fields()[2], value); + this.maybe_long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'maybe_long_column' field has been set */ + public boolean hasMaybeLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'maybe_long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeLongColumn() { + maybe_long_column = null; + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'maybe_float_column' field */ + public java.lang.Float getMaybeFloatColumn() { + return maybe_float_column; + } + + /** Sets the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeFloatColumn(java.lang.Float value) { + validate(fields()[3], value); + this.maybe_float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'maybe_float_column' field has been set */ + public boolean hasMaybeFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'maybe_float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder 
clearMaybeFloatColumn() { + maybe_float_column = null; + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'maybe_double_column' field */ + public java.lang.Double getMaybeDoubleColumn() { + return maybe_double_column; + } + + /** Sets the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeDoubleColumn(java.lang.Double value) { + validate(fields()[4], value); + this.maybe_double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'maybe_double_column' field has been set */ + public boolean hasMaybeDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'maybe_double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeDoubleColumn() { + maybe_double_column = null; + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'maybe_binary_column' field */ + public java.nio.ByteBuffer getMaybeBinaryColumn() { + return maybe_binary_column; + } + + /** Sets the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.maybe_binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'maybe_binary_column' field has been set */ + public boolean hasMaybeBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'maybe_binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeBinaryColumn() { + maybe_binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'maybe_string_column' field */ + public java.lang.String getMaybeStringColumn() { + return maybe_string_column; + } + + /** Sets the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder setMaybeStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.maybe_string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'maybe_string_column' field has been set */ + public boolean hasMaybeStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'maybe_string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroOptionalPrimitives.Builder clearMaybeStringColumn() { + maybe_string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroOptionalPrimitives build() { + try { + AvroOptionalPrimitives record = new AvroOptionalPrimitives(); + record.maybe_bool_column = fieldSetFlags()[0] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.maybe_int_column = fieldSetFlags()[1] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.maybe_long_column = fieldSetFlags()[2] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[2]); + record.maybe_float_column = fieldSetFlags()[3] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[3]); + record.maybe_double_column = fieldSetFlags()[4] ? 
this.maybe_double_column : (java.lang.Double) defaultValue(fields()[4]); + record.maybe_binary_column = fieldSetFlags()[5] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.maybe_string_column = fieldSetFlags()[6] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java new file mode 100644 index 0000000000000..1c2afed16781e --- /dev/null +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/AvroPrimitives.java @@ -0,0 +1,461 @@ +/** + * Autogenerated by Avro + * + * DO NOT EDIT DIRECTLY + */ +package org.apache.spark.sql.execution.datasources.parquet.test.avro; +@SuppressWarnings("all") +@org.apache.avro.specific.AvroGenerated +public class AvroPrimitives extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}"); + public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } + @Deprecated public boolean bool_column; + @Deprecated public int int_column; + @Deprecated public long long_column; + @Deprecated public float float_column; + @Deprecated public double double_column; + @Deprecated public java.nio.ByteBuffer binary_column; + @Deprecated public java.lang.String string_column; + + /** + * Default constructor. Note that this does not initialize fields + * to their default values from the schema. If that is desired then + * one should use newBuilder(). + */ + public AvroPrimitives() {} + + /** + * All-args constructor. + */ + public AvroPrimitives(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column) { + this.bool_column = bool_column; + this.int_column = int_column; + this.long_column = long_column; + this.float_column = float_column; + this.double_column = double_column; + this.binary_column = binary_column; + this.string_column = string_column; + } + + public org.apache.avro.Schema getSchema() { return SCHEMA$; } + // Used by DatumWriter. Applications should not call. + public java.lang.Object get(int field$) { + switch (field$) { + case 0: return bool_column; + case 1: return int_column; + case 2: return long_column; + case 3: return float_column; + case 4: return double_column; + case 5: return binary_column; + case 6: return string_column; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + // Used by DatumReader. Applications should not call. 
+ @SuppressWarnings(value="unchecked") + public void put(int field$, java.lang.Object value$) { + switch (field$) { + case 0: bool_column = (java.lang.Boolean)value$; break; + case 1: int_column = (java.lang.Integer)value$; break; + case 2: long_column = (java.lang.Long)value$; break; + case 3: float_column = (java.lang.Float)value$; break; + case 4: double_column = (java.lang.Double)value$; break; + case 5: binary_column = (java.nio.ByteBuffer)value$; break; + case 6: string_column = (java.lang.String)value$; break; + default: throw new org.apache.avro.AvroRuntimeException("Bad index"); + } + } + + /** + * Gets the value of the 'bool_column' field. + */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** + * Sets the value of the 'bool_column' field. + * @param value the value to set. + */ + public void setBoolColumn(java.lang.Boolean value) { + this.bool_column = value; + } + + /** + * Gets the value of the 'int_column' field. + */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** + * Sets the value of the 'int_column' field. + * @param value the value to set. + */ + public void setIntColumn(java.lang.Integer value) { + this.int_column = value; + } + + /** + * Gets the value of the 'long_column' field. + */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** + * Sets the value of the 'long_column' field. + * @param value the value to set. + */ + public void setLongColumn(java.lang.Long value) { + this.long_column = value; + } + + /** + * Gets the value of the 'float_column' field. + */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** + * Sets the value of the 'float_column' field. + * @param value the value to set. + */ + public void setFloatColumn(java.lang.Float value) { + this.float_column = value; + } + + /** + * Gets the value of the 'double_column' field. + */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** + * Sets the value of the 'double_column' field. + * @param value the value to set. + */ + public void setDoubleColumn(java.lang.Double value) { + this.double_column = value; + } + + /** + * Gets the value of the 'binary_column' field. + */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** + * Sets the value of the 'binary_column' field. + * @param value the value to set. + */ + public void setBinaryColumn(java.nio.ByteBuffer value) { + this.binary_column = value; + } + + /** + * Gets the value of the 'string_column' field. + */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** + * Sets the value of the 'string_column' field. + * @param value the value to set. 
+ */ + public void setStringColumn(java.lang.String value) { + this.string_column = value; + } + + /** Creates a new AvroPrimitives RecordBuilder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder() { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing Builder */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** Creates a new AvroPrimitives RecordBuilder by copying an existing AvroPrimitives instance */ + public static org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder newBuilder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + return new org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder(other); + } + + /** + * RecordBuilder for AvroPrimitives instances. + */ + public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase + implements org.apache.avro.data.RecordBuilder { + + private boolean bool_column; + private int int_column; + private long long_column; + private float float_column; + private double double_column; + private java.nio.ByteBuffer binary_column; + private java.lang.String string_column; + + /** Creates a new Builder */ + private Builder() { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + } + + /** Creates a Builder by copying an existing Builder */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder other) { + super(other); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Creates a Builder by copying an existing AvroPrimitives instance */ + private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives other) { + super(org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.SCHEMA$); + if (isValidValue(fields()[0], other.bool_column)) { + this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + 
fieldSetFlags()[0] = true; + } + if (isValidValue(fields()[1], other.int_column)) { + this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + fieldSetFlags()[1] = true; + } + if (isValidValue(fields()[2], other.long_column)) { + this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + fieldSetFlags()[2] = true; + } + if (isValidValue(fields()[3], other.float_column)) { + this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); + fieldSetFlags()[3] = true; + } + if (isValidValue(fields()[4], other.double_column)) { + this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); + fieldSetFlags()[4] = true; + } + if (isValidValue(fields()[5], other.binary_column)) { + this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); + fieldSetFlags()[5] = true; + } + if (isValidValue(fields()[6], other.string_column)) { + this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); + fieldSetFlags()[6] = true; + } + } + + /** Gets the value of the 'bool_column' field */ + public java.lang.Boolean getBoolColumn() { + return bool_column; + } + + /** Sets the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBoolColumn(boolean value) { + validate(fields()[0], value); + this.bool_column = value; + fieldSetFlags()[0] = true; + return this; + } + + /** Checks whether the 'bool_column' field has been set */ + public boolean hasBoolColumn() { + return fieldSetFlags()[0]; + } + + /** Clears the value of the 'bool_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBoolColumn() { + fieldSetFlags()[0] = false; + return this; + } + + /** Gets the value of the 'int_column' field */ + public java.lang.Integer getIntColumn() { + return int_column; + } + + /** Sets the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setIntColumn(int value) { + validate(fields()[1], value); + this.int_column = value; + fieldSetFlags()[1] = true; + return this; + } + + /** Checks whether the 'int_column' field has been set */ + public boolean hasIntColumn() { + return fieldSetFlags()[1]; + } + + /** Clears the value of the 'int_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearIntColumn() { + fieldSetFlags()[1] = false; + return this; + } + + /** Gets the value of the 'long_column' field */ + public java.lang.Long getLongColumn() { + return long_column; + } + + /** Sets the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setLongColumn(long value) { + validate(fields()[2], value); + this.long_column = value; + fieldSetFlags()[2] = true; + return this; + } + + /** Checks whether the 'long_column' field has been set */ + public boolean hasLongColumn() { + return fieldSetFlags()[2]; + } + + /** Clears the value of the 'long_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearLongColumn() { + fieldSetFlags()[2] = false; + return this; + } + + /** Gets the value of the 'float_column' field */ + public java.lang.Float getFloatColumn() { + return float_column; + } + + /** Sets the value of the 'float_column' field */ + public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setFloatColumn(float value) { + validate(fields()[3], value); + this.float_column = value; + fieldSetFlags()[3] = true; + return this; + } + + /** Checks whether the 'float_column' field has been set */ + public boolean hasFloatColumn() { + return fieldSetFlags()[3]; + } + + /** Clears the value of the 'float_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearFloatColumn() { + fieldSetFlags()[3] = false; + return this; + } + + /** Gets the value of the 'double_column' field */ + public java.lang.Double getDoubleColumn() { + return double_column; + } + + /** Sets the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setDoubleColumn(double value) { + validate(fields()[4], value); + this.double_column = value; + fieldSetFlags()[4] = true; + return this; + } + + /** Checks whether the 'double_column' field has been set */ + public boolean hasDoubleColumn() { + return fieldSetFlags()[4]; + } + + /** Clears the value of the 'double_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearDoubleColumn() { + fieldSetFlags()[4] = false; + return this; + } + + /** Gets the value of the 'binary_column' field */ + public java.nio.ByteBuffer getBinaryColumn() { + return binary_column; + } + + /** Sets the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setBinaryColumn(java.nio.ByteBuffer value) { + validate(fields()[5], value); + this.binary_column = value; + fieldSetFlags()[5] = true; + return this; + } + + /** Checks whether the 'binary_column' field has been set */ + public boolean hasBinaryColumn() { + return fieldSetFlags()[5]; + } + + /** Clears the value of the 'binary_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearBinaryColumn() { + binary_column = null; + fieldSetFlags()[5] = false; + return this; + } + + /** Gets the value of the 'string_column' field */ + public java.lang.String getStringColumn() { + return string_column; + } + + /** Sets the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder setStringColumn(java.lang.String value) { + validate(fields()[6], value); + this.string_column = value; + fieldSetFlags()[6] = true; + return this; + } + + /** Checks whether the 'string_column' field has been set */ + public boolean hasStringColumn() { + return fieldSetFlags()[6]; + } + + /** Clears the value of the 'string_column' field */ + public org.apache.spark.sql.execution.datasources.parquet.test.avro.AvroPrimitives.Builder clearStringColumn() { + string_column = null; + fieldSetFlags()[6] = false; + return this; + } + + @Override + public AvroPrimitives build() { + try { + AvroPrimitives record = new AvroPrimitives(); + record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); + record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); + record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); + record.float_column = fieldSetFlags()[3] ? 
this.float_column : (java.lang.Float) defaultValue(fields()[3]); + record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); + record.binary_column = fieldSetFlags()[5] ? this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); + record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); + return record; + } catch (Exception e) { + throw new org.apache.avro.AvroRuntimeException(e); + } + } + } +} diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java index 2368323cb36b9..28fdc1dfb911c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/CompatibilityTest.java @@ -8,7 +8,7 @@ @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public interface CompatibilityTest { - public static final org.apache.avro.Protocol PROTOCOL = org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); + public static final org.apache.avro.Protocol PROTOCOL = 
org.apache.avro.Protocol.parse("{\"protocol\":\"CompatibilityTest\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"types\":[{\"type\":\"enum\",\"name\":\"Suit\",\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]},{\"type\":\"record\",\"name\":\"ParquetEnum\",\"fields\":[{\"name\":\"suit\",\"type\":\"Suit\"}]},{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroPrimitives\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"AvroOptionalPrimitives\",\"fields\":[{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]}]},{\"type\":\"record\",\"name\":\"AvroNonNullableArrays\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"maybe_ints_column\",\"type\":[\"null\",{\"type\":\"array\",\"items\":\"int\"}]}]},{\"type\":\"record\",\"name\":\"AvroArrayOfArray\",\"fields\":[{\"name\":\"int_arrays_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"array\",\"items\":\"int\"}}}]},{\"type\":\"record\",\"name\":\"AvroMapOfArray\",\"fields\":[{\"name\":\"string_to_ints_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"int\"},\"avro.java.string\":\"String\"}}]},{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":\"Nested\"},\"avro.java.string\":\"String\"}}]}],\"messages\":{}}"); @SuppressWarnings("all") public interface Callback extends CompatibilityTest { diff --git a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java index 681cacbd12c7c..ef12d193f916c 100644 --- a/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java +++ b/sql/core/src/test/gen-java/org/apache/spark/sql/execution/datasources/parquet/test/avro/ParquetAvroCompat.java @@ -7,22 +7,8 @@ @SuppressWarnings("all") @org.apache.avro.specific.AvroGenerated public class ParquetAvroCompat extends org.apache.avro.specific.SpecificRecordBase implements org.apache.avro.specific.SpecificRecord { - public static 
final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"bool_column\",\"type\":\"boolean\"},{\"name\":\"int_column\",\"type\":\"int\"},{\"name\":\"long_column\",\"type\":\"long\"},{\"name\":\"float_column\",\"type\":\"float\"},{\"name\":\"double_column\",\"type\":\"double\"},{\"name\":\"binary_column\",\"type\":\"bytes\"},{\"name\":\"string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}},{\"name\":\"maybe_bool_column\",\"type\":[\"null\",\"boolean\"]},{\"name\":\"maybe_int_column\",\"type\":[\"null\",\"int\"]},{\"name\":\"maybe_long_column\",\"type\":[\"null\",\"long\"]},{\"name\":\"maybe_float_column\",\"type\":[\"null\",\"float\"]},{\"name\":\"maybe_double_column\",\"type\":[\"null\",\"double\"]},{\"name\":\"maybe_binary_column\",\"type\":[\"null\",\"bytes\"]},{\"name\":\"maybe_string_column\",\"type\":[\"null\",{\"type\":\"string\",\"avro.java.string\":\"String\"}]},{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); + public static final org.apache.avro.Schema SCHEMA$ = new org.apache.avro.Schema.Parser().parse("{\"type\":\"record\",\"name\":\"ParquetAvroCompat\",\"namespace\":\"org.apache.spark.sql.execution.datasources.parquet.test.avro\",\"fields\":[{\"name\":\"strings_column\",\"type\":{\"type\":\"array\",\"items\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}},{\"name\":\"string_to_int_column\",\"type\":{\"type\":\"map\",\"values\":\"int\",\"avro.java.string\":\"String\"}},{\"name\":\"complex_column\",\"type\":{\"type\":\"map\",\"values\":{\"type\":\"array\",\"items\":{\"type\":\"record\",\"name\":\"Nested\",\"fields\":[{\"name\":\"nested_ints_column\",\"type\":{\"type\":\"array\",\"items\":\"int\"}},{\"name\":\"nested_string_column\",\"type\":{\"type\":\"string\",\"avro.java.string\":\"String\"}}]}},\"avro.java.string\":\"String\"}}]}"); public static org.apache.avro.Schema getClassSchema() { return SCHEMA$; } - @Deprecated public boolean bool_column; - @Deprecated public int int_column; - @Deprecated public long long_column; - @Deprecated public float float_column; - @Deprecated public double double_column; - @Deprecated public java.nio.ByteBuffer binary_column; - @Deprecated public java.lang.String string_column; - @Deprecated public java.lang.Boolean maybe_bool_column; - @Deprecated public java.lang.Integer maybe_int_column; - @Deprecated public java.lang.Long maybe_long_column; - @Deprecated public java.lang.Float maybe_float_column; - @Deprecated public java.lang.Double maybe_double_column; - @Deprecated public java.nio.ByteBuffer maybe_binary_column; - @Deprecated public java.lang.String maybe_string_column; @Deprecated public java.util.List strings_column; @Deprecated public java.util.Map string_to_int_column; @Deprecated public java.util.Map> complex_column; @@ -37,21 +23,7 @@ public ParquetAvroCompat() {} /** * 
All-args constructor. */ - public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_column, java.lang.Long long_column, java.lang.Float float_column, java.lang.Double double_column, java.nio.ByteBuffer binary_column, java.lang.String string_column, java.lang.Boolean maybe_bool_column, java.lang.Integer maybe_int_column, java.lang.Long maybe_long_column, java.lang.Float maybe_float_column, java.lang.Double maybe_double_column, java.nio.ByteBuffer maybe_binary_column, java.lang.String maybe_string_column, java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { - this.bool_column = bool_column; - this.int_column = int_column; - this.long_column = long_column; - this.float_column = float_column; - this.double_column = double_column; - this.binary_column = binary_column; - this.string_column = string_column; - this.maybe_bool_column = maybe_bool_column; - this.maybe_int_column = maybe_int_column; - this.maybe_long_column = maybe_long_column; - this.maybe_float_column = maybe_float_column; - this.maybe_double_column = maybe_double_column; - this.maybe_binary_column = maybe_binary_column; - this.maybe_string_column = maybe_string_column; + public ParquetAvroCompat(java.util.List strings_column, java.util.Map string_to_int_column, java.util.Map> complex_column) { this.strings_column = strings_column; this.string_to_int_column = string_to_int_column; this.complex_column = complex_column; @@ -61,23 +33,9 @@ public ParquetAvroCompat(java.lang.Boolean bool_column, java.lang.Integer int_co // Used by DatumWriter. Applications should not call. public java.lang.Object get(int field$) { switch (field$) { - case 0: return bool_column; - case 1: return int_column; - case 2: return long_column; - case 3: return float_column; - case 4: return double_column; - case 5: return binary_column; - case 6: return string_column; - case 7: return maybe_bool_column; - case 8: return maybe_int_column; - case 9: return maybe_long_column; - case 10: return maybe_float_column; - case 11: return maybe_double_column; - case 12: return maybe_binary_column; - case 13: return maybe_string_column; - case 14: return strings_column; - case 15: return string_to_int_column; - case 16: return complex_column; + case 0: return strings_column; + case 1: return string_to_int_column; + case 2: return complex_column; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } } @@ -85,237 +43,13 @@ public java.lang.Object get(int field$) { @SuppressWarnings(value="unchecked") public void put(int field$, java.lang.Object value$) { switch (field$) { - case 0: bool_column = (java.lang.Boolean)value$; break; - case 1: int_column = (java.lang.Integer)value$; break; - case 2: long_column = (java.lang.Long)value$; break; - case 3: float_column = (java.lang.Float)value$; break; - case 4: double_column = (java.lang.Double)value$; break; - case 5: binary_column = (java.nio.ByteBuffer)value$; break; - case 6: string_column = (java.lang.String)value$; break; - case 7: maybe_bool_column = (java.lang.Boolean)value$; break; - case 8: maybe_int_column = (java.lang.Integer)value$; break; - case 9: maybe_long_column = (java.lang.Long)value$; break; - case 10: maybe_float_column = (java.lang.Float)value$; break; - case 11: maybe_double_column = (java.lang.Double)value$; break; - case 12: maybe_binary_column = (java.nio.ByteBuffer)value$; break; - case 13: maybe_string_column = (java.lang.String)value$; break; - case 14: strings_column = (java.util.List)value$; break; - case 15: 
string_to_int_column = (java.util.Map)value$; break; - case 16: complex_column = (java.util.Map>)value$; break; + case 0: strings_column = (java.util.List)value$; break; + case 1: string_to_int_column = (java.util.Map)value$; break; + case 2: complex_column = (java.util.Map>)value$; break; default: throw new org.apache.avro.AvroRuntimeException("Bad index"); } } - /** - * Gets the value of the 'bool_column' field. - */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** - * Sets the value of the 'bool_column' field. - * @param value the value to set. - */ - public void setBoolColumn(java.lang.Boolean value) { - this.bool_column = value; - } - - /** - * Gets the value of the 'int_column' field. - */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** - * Sets the value of the 'int_column' field. - * @param value the value to set. - */ - public void setIntColumn(java.lang.Integer value) { - this.int_column = value; - } - - /** - * Gets the value of the 'long_column' field. - */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** - * Sets the value of the 'long_column' field. - * @param value the value to set. - */ - public void setLongColumn(java.lang.Long value) { - this.long_column = value; - } - - /** - * Gets the value of the 'float_column' field. - */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** - * Sets the value of the 'float_column' field. - * @param value the value to set. - */ - public void setFloatColumn(java.lang.Float value) { - this.float_column = value; - } - - /** - * Gets the value of the 'double_column' field. - */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** - * Sets the value of the 'double_column' field. - * @param value the value to set. - */ - public void setDoubleColumn(java.lang.Double value) { - this.double_column = value; - } - - /** - * Gets the value of the 'binary_column' field. - */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** - * Sets the value of the 'binary_column' field. - * @param value the value to set. - */ - public void setBinaryColumn(java.nio.ByteBuffer value) { - this.binary_column = value; - } - - /** - * Gets the value of the 'string_column' field. - */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** - * Sets the value of the 'string_column' field. - * @param value the value to set. - */ - public void setStringColumn(java.lang.String value) { - this.string_column = value; - } - - /** - * Gets the value of the 'maybe_bool_column' field. - */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** - * Sets the value of the 'maybe_bool_column' field. - * @param value the value to set. - */ - public void setMaybeBoolColumn(java.lang.Boolean value) { - this.maybe_bool_column = value; - } - - /** - * Gets the value of the 'maybe_int_column' field. - */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** - * Sets the value of the 'maybe_int_column' field. - * @param value the value to set. - */ - public void setMaybeIntColumn(java.lang.Integer value) { - this.maybe_int_column = value; - } - - /** - * Gets the value of the 'maybe_long_column' field. - */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** - * Sets the value of the 'maybe_long_column' field. - * @param value the value to set. 
- */ - public void setMaybeLongColumn(java.lang.Long value) { - this.maybe_long_column = value; - } - - /** - * Gets the value of the 'maybe_float_column' field. - */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** - * Sets the value of the 'maybe_float_column' field. - * @param value the value to set. - */ - public void setMaybeFloatColumn(java.lang.Float value) { - this.maybe_float_column = value; - } - - /** - * Gets the value of the 'maybe_double_column' field. - */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** - * Sets the value of the 'maybe_double_column' field. - * @param value the value to set. - */ - public void setMaybeDoubleColumn(java.lang.Double value) { - this.maybe_double_column = value; - } - - /** - * Gets the value of the 'maybe_binary_column' field. - */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** - * Sets the value of the 'maybe_binary_column' field. - * @param value the value to set. - */ - public void setMaybeBinaryColumn(java.nio.ByteBuffer value) { - this.maybe_binary_column = value; - } - - /** - * Gets the value of the 'maybe_string_column' field. - */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** - * Sets the value of the 'maybe_string_column' field. - * @param value the value to set. - */ - public void setMaybeStringColumn(java.lang.String value) { - this.maybe_string_column = value; - } - /** * Gets the value of the 'strings_column' field. */ @@ -382,20 +116,6 @@ public static org.apache.spark.sql.execution.datasources.parquet.test.avro.Parqu public static class Builder extends org.apache.avro.specific.SpecificRecordBuilderBase implements org.apache.avro.data.RecordBuilder { - private boolean bool_column; - private int int_column; - private long long_column; - private float float_column; - private double double_column; - private java.nio.ByteBuffer binary_column; - private java.lang.String string_column; - private java.lang.Boolean maybe_bool_column; - private java.lang.Integer maybe_int_column; - private java.lang.Long maybe_long_column; - private java.lang.Float maybe_float_column; - private java.lang.Double maybe_double_column; - private java.nio.ByteBuffer maybe_binary_column; - private java.lang.String maybe_string_column; private java.util.List strings_column; private java.util.Map string_to_int_column; private java.util.Map> complex_column; @@ -408,492 +128,35 @@ private Builder() { /** Creates a Builder by copying an existing Builder */ private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder other) { super(other); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); fieldSetFlags()[0] = true; } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); fieldSetFlags()[1] = true; } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + if (isValidValue(fields()[2], other.complex_column)) { + 
this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); fieldSetFlags()[2] = true; } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } } /** Creates a Builder by copying an existing ParquetAvroCompat instance */ private Builder(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat other) { super(org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.SCHEMA$); - if (isValidValue(fields()[0], other.bool_column)) { - this.bool_column = data().deepCopy(fields()[0].schema(), other.bool_column); + if (isValidValue(fields()[0], other.strings_column)) { + this.strings_column = data().deepCopy(fields()[0].schema(), other.strings_column); fieldSetFlags()[0] = true; } - if (isValidValue(fields()[1], other.int_column)) { - this.int_column = data().deepCopy(fields()[1].schema(), other.int_column); + if (isValidValue(fields()[1], other.string_to_int_column)) { + this.string_to_int_column = data().deepCopy(fields()[1].schema(), other.string_to_int_column); 
fieldSetFlags()[1] = true; } - if (isValidValue(fields()[2], other.long_column)) { - this.long_column = data().deepCopy(fields()[2].schema(), other.long_column); + if (isValidValue(fields()[2], other.complex_column)) { + this.complex_column = data().deepCopy(fields()[2].schema(), other.complex_column); fieldSetFlags()[2] = true; } - if (isValidValue(fields()[3], other.float_column)) { - this.float_column = data().deepCopy(fields()[3].schema(), other.float_column); - fieldSetFlags()[3] = true; - } - if (isValidValue(fields()[4], other.double_column)) { - this.double_column = data().deepCopy(fields()[4].schema(), other.double_column); - fieldSetFlags()[4] = true; - } - if (isValidValue(fields()[5], other.binary_column)) { - this.binary_column = data().deepCopy(fields()[5].schema(), other.binary_column); - fieldSetFlags()[5] = true; - } - if (isValidValue(fields()[6], other.string_column)) { - this.string_column = data().deepCopy(fields()[6].schema(), other.string_column); - fieldSetFlags()[6] = true; - } - if (isValidValue(fields()[7], other.maybe_bool_column)) { - this.maybe_bool_column = data().deepCopy(fields()[7].schema(), other.maybe_bool_column); - fieldSetFlags()[7] = true; - } - if (isValidValue(fields()[8], other.maybe_int_column)) { - this.maybe_int_column = data().deepCopy(fields()[8].schema(), other.maybe_int_column); - fieldSetFlags()[8] = true; - } - if (isValidValue(fields()[9], other.maybe_long_column)) { - this.maybe_long_column = data().deepCopy(fields()[9].schema(), other.maybe_long_column); - fieldSetFlags()[9] = true; - } - if (isValidValue(fields()[10], other.maybe_float_column)) { - this.maybe_float_column = data().deepCopy(fields()[10].schema(), other.maybe_float_column); - fieldSetFlags()[10] = true; - } - if (isValidValue(fields()[11], other.maybe_double_column)) { - this.maybe_double_column = data().deepCopy(fields()[11].schema(), other.maybe_double_column); - fieldSetFlags()[11] = true; - } - if (isValidValue(fields()[12], other.maybe_binary_column)) { - this.maybe_binary_column = data().deepCopy(fields()[12].schema(), other.maybe_binary_column); - fieldSetFlags()[12] = true; - } - if (isValidValue(fields()[13], other.maybe_string_column)) { - this.maybe_string_column = data().deepCopy(fields()[13].schema(), other.maybe_string_column); - fieldSetFlags()[13] = true; - } - if (isValidValue(fields()[14], other.strings_column)) { - this.strings_column = data().deepCopy(fields()[14].schema(), other.strings_column); - fieldSetFlags()[14] = true; - } - if (isValidValue(fields()[15], other.string_to_int_column)) { - this.string_to_int_column = data().deepCopy(fields()[15].schema(), other.string_to_int_column); - fieldSetFlags()[15] = true; - } - if (isValidValue(fields()[16], other.complex_column)) { - this.complex_column = data().deepCopy(fields()[16].schema(), other.complex_column); - fieldSetFlags()[16] = true; - } - } - - /** Gets the value of the 'bool_column' field */ - public java.lang.Boolean getBoolColumn() { - return bool_column; - } - - /** Sets the value of the 'bool_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBoolColumn(boolean value) { - validate(fields()[0], value); - this.bool_column = value; - fieldSetFlags()[0] = true; - return this; - } - - /** Checks whether the 'bool_column' field has been set */ - public boolean hasBoolColumn() { - return fieldSetFlags()[0]; - } - - /** Clears the value of the 'bool_column' field */ - public 
org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBoolColumn() { - fieldSetFlags()[0] = false; - return this; - } - - /** Gets the value of the 'int_column' field */ - public java.lang.Integer getIntColumn() { - return int_column; - } - - /** Sets the value of the 'int_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setIntColumn(int value) { - validate(fields()[1], value); - this.int_column = value; - fieldSetFlags()[1] = true; - return this; - } - - /** Checks whether the 'int_column' field has been set */ - public boolean hasIntColumn() { - return fieldSetFlags()[1]; - } - - /** Clears the value of the 'int_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearIntColumn() { - fieldSetFlags()[1] = false; - return this; - } - - /** Gets the value of the 'long_column' field */ - public java.lang.Long getLongColumn() { - return long_column; - } - - /** Sets the value of the 'long_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setLongColumn(long value) { - validate(fields()[2], value); - this.long_column = value; - fieldSetFlags()[2] = true; - return this; - } - - /** Checks whether the 'long_column' field has been set */ - public boolean hasLongColumn() { - return fieldSetFlags()[2]; - } - - /** Clears the value of the 'long_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearLongColumn() { - fieldSetFlags()[2] = false; - return this; - } - - /** Gets the value of the 'float_column' field */ - public java.lang.Float getFloatColumn() { - return float_column; - } - - /** Sets the value of the 'float_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setFloatColumn(float value) { - validate(fields()[3], value); - this.float_column = value; - fieldSetFlags()[3] = true; - return this; - } - - /** Checks whether the 'float_column' field has been set */ - public boolean hasFloatColumn() { - return fieldSetFlags()[3]; - } - - /** Clears the value of the 'float_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearFloatColumn() { - fieldSetFlags()[3] = false; - return this; - } - - /** Gets the value of the 'double_column' field */ - public java.lang.Double getDoubleColumn() { - return double_column; - } - - /** Sets the value of the 'double_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setDoubleColumn(double value) { - validate(fields()[4], value); - this.double_column = value; - fieldSetFlags()[4] = true; - return this; - } - - /** Checks whether the 'double_column' field has been set */ - public boolean hasDoubleColumn() { - return fieldSetFlags()[4]; - } - - /** Clears the value of the 'double_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearDoubleColumn() { - fieldSetFlags()[4] = false; - return this; - } - - /** Gets the value of the 'binary_column' field */ - public java.nio.ByteBuffer getBinaryColumn() { - return binary_column; - } - - /** Sets the value of the 'binary_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setBinaryColumn(java.nio.ByteBuffer value) { - 
validate(fields()[5], value); - this.binary_column = value; - fieldSetFlags()[5] = true; - return this; - } - - /** Checks whether the 'binary_column' field has been set */ - public boolean hasBinaryColumn() { - return fieldSetFlags()[5]; - } - - /** Clears the value of the 'binary_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearBinaryColumn() { - binary_column = null; - fieldSetFlags()[5] = false; - return this; - } - - /** Gets the value of the 'string_column' field */ - public java.lang.String getStringColumn() { - return string_column; - } - - /** Sets the value of the 'string_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringColumn(java.lang.String value) { - validate(fields()[6], value); - this.string_column = value; - fieldSetFlags()[6] = true; - return this; - } - - /** Checks whether the 'string_column' field has been set */ - public boolean hasStringColumn() { - return fieldSetFlags()[6]; - } - - /** Clears the value of the 'string_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringColumn() { - string_column = null; - fieldSetFlags()[6] = false; - return this; - } - - /** Gets the value of the 'maybe_bool_column' field */ - public java.lang.Boolean getMaybeBoolColumn() { - return maybe_bool_column; - } - - /** Sets the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBoolColumn(java.lang.Boolean value) { - validate(fields()[7], value); - this.maybe_bool_column = value; - fieldSetFlags()[7] = true; - return this; - } - - /** Checks whether the 'maybe_bool_column' field has been set */ - public boolean hasMaybeBoolColumn() { - return fieldSetFlags()[7]; - } - - /** Clears the value of the 'maybe_bool_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBoolColumn() { - maybe_bool_column = null; - fieldSetFlags()[7] = false; - return this; - } - - /** Gets the value of the 'maybe_int_column' field */ - public java.lang.Integer getMaybeIntColumn() { - return maybe_int_column; - } - - /** Sets the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeIntColumn(java.lang.Integer value) { - validate(fields()[8], value); - this.maybe_int_column = value; - fieldSetFlags()[8] = true; - return this; - } - - /** Checks whether the 'maybe_int_column' field has been set */ - public boolean hasMaybeIntColumn() { - return fieldSetFlags()[8]; - } - - /** Clears the value of the 'maybe_int_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeIntColumn() { - maybe_int_column = null; - fieldSetFlags()[8] = false; - return this; - } - - /** Gets the value of the 'maybe_long_column' field */ - public java.lang.Long getMaybeLongColumn() { - return maybe_long_column; - } - - /** Sets the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeLongColumn(java.lang.Long value) { - validate(fields()[9], value); - this.maybe_long_column = value; - fieldSetFlags()[9] = true; - return this; - } - - /** Checks whether the 'maybe_long_column' field has been set */ - public boolean 
hasMaybeLongColumn() { - return fieldSetFlags()[9]; - } - - /** Clears the value of the 'maybe_long_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeLongColumn() { - maybe_long_column = null; - fieldSetFlags()[9] = false; - return this; - } - - /** Gets the value of the 'maybe_float_column' field */ - public java.lang.Float getMaybeFloatColumn() { - return maybe_float_column; - } - - /** Sets the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeFloatColumn(java.lang.Float value) { - validate(fields()[10], value); - this.maybe_float_column = value; - fieldSetFlags()[10] = true; - return this; - } - - /** Checks whether the 'maybe_float_column' field has been set */ - public boolean hasMaybeFloatColumn() { - return fieldSetFlags()[10]; - } - - /** Clears the value of the 'maybe_float_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeFloatColumn() { - maybe_float_column = null; - fieldSetFlags()[10] = false; - return this; - } - - /** Gets the value of the 'maybe_double_column' field */ - public java.lang.Double getMaybeDoubleColumn() { - return maybe_double_column; - } - - /** Sets the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeDoubleColumn(java.lang.Double value) { - validate(fields()[11], value); - this.maybe_double_column = value; - fieldSetFlags()[11] = true; - return this; - } - - /** Checks whether the 'maybe_double_column' field has been set */ - public boolean hasMaybeDoubleColumn() { - return fieldSetFlags()[11]; - } - - /** Clears the value of the 'maybe_double_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeDoubleColumn() { - maybe_double_column = null; - fieldSetFlags()[11] = false; - return this; - } - - /** Gets the value of the 'maybe_binary_column' field */ - public java.nio.ByteBuffer getMaybeBinaryColumn() { - return maybe_binary_column; - } - - /** Sets the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeBinaryColumn(java.nio.ByteBuffer value) { - validate(fields()[12], value); - this.maybe_binary_column = value; - fieldSetFlags()[12] = true; - return this; - } - - /** Checks whether the 'maybe_binary_column' field has been set */ - public boolean hasMaybeBinaryColumn() { - return fieldSetFlags()[12]; - } - - /** Clears the value of the 'maybe_binary_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeBinaryColumn() { - maybe_binary_column = null; - fieldSetFlags()[12] = false; - return this; - } - - /** Gets the value of the 'maybe_string_column' field */ - public java.lang.String getMaybeStringColumn() { - return maybe_string_column; - } - - /** Sets the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setMaybeStringColumn(java.lang.String value) { - validate(fields()[13], value); - this.maybe_string_column = value; - fieldSetFlags()[13] = true; - return this; - } - - /** Checks whether the 'maybe_string_column' field has been set */ - public boolean hasMaybeStringColumn() { - return 
fieldSetFlags()[13]; - } - - /** Clears the value of the 'maybe_string_column' field */ - public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearMaybeStringColumn() { - maybe_string_column = null; - fieldSetFlags()[13] = false; - return this; } /** Gets the value of the 'strings_column' field */ @@ -903,21 +166,21 @@ public java.util.List getStringsColumn() { /** Sets the value of the 'strings_column' field */ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringsColumn(java.util.List value) { - validate(fields()[14], value); + validate(fields()[0], value); this.strings_column = value; - fieldSetFlags()[14] = true; + fieldSetFlags()[0] = true; return this; } /** Checks whether the 'strings_column' field has been set */ public boolean hasStringsColumn() { - return fieldSetFlags()[14]; + return fieldSetFlags()[0]; } /** Clears the value of the 'strings_column' field */ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringsColumn() { strings_column = null; - fieldSetFlags()[14] = false; + fieldSetFlags()[0] = false; return this; } @@ -928,21 +191,21 @@ public java.util.Map getStringToIntColumn() /** Sets the value of the 'string_to_int_column' field */ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder setStringToIntColumn(java.util.Map value) { - validate(fields()[15], value); + validate(fields()[1], value); this.string_to_int_column = value; - fieldSetFlags()[15] = true; + fieldSetFlags()[1] = true; return this; } /** Checks whether the 'string_to_int_column' field has been set */ public boolean hasStringToIntColumn() { - return fieldSetFlags()[15]; + return fieldSetFlags()[1]; } /** Clears the value of the 'string_to_int_column' field */ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearStringToIntColumn() { string_to_int_column = null; - fieldSetFlags()[15] = false; + fieldSetFlags()[1] = false; return this; } @@ -953,21 +216,21 @@ public java.util.Map> value) { - validate(fields()[16], value); + validate(fields()[2], value); this.complex_column = value; - fieldSetFlags()[16] = true; + fieldSetFlags()[2] = true; return this; } /** Checks whether the 'complex_column' field has been set */ public boolean hasComplexColumn() { - return fieldSetFlags()[16]; + return fieldSetFlags()[2]; } /** Clears the value of the 'complex_column' field */ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroCompat.Builder clearComplexColumn() { complex_column = null; - fieldSetFlags()[16] = false; + fieldSetFlags()[2] = false; return this; } @@ -975,23 +238,9 @@ public org.apache.spark.sql.execution.datasources.parquet.test.avro.ParquetAvroC public ParquetAvroCompat build() { try { ParquetAvroCompat record = new ParquetAvroCompat(); - record.bool_column = fieldSetFlags()[0] ? this.bool_column : (java.lang.Boolean) defaultValue(fields()[0]); - record.int_column = fieldSetFlags()[1] ? this.int_column : (java.lang.Integer) defaultValue(fields()[1]); - record.long_column = fieldSetFlags()[2] ? this.long_column : (java.lang.Long) defaultValue(fields()[2]); - record.float_column = fieldSetFlags()[3] ? this.float_column : (java.lang.Float) defaultValue(fields()[3]); - record.double_column = fieldSetFlags()[4] ? this.double_column : (java.lang.Double) defaultValue(fields()[4]); - record.binary_column = fieldSetFlags()[5] ? 
this.binary_column : (java.nio.ByteBuffer) defaultValue(fields()[5]); - record.string_column = fieldSetFlags()[6] ? this.string_column : (java.lang.String) defaultValue(fields()[6]); - record.maybe_bool_column = fieldSetFlags()[7] ? this.maybe_bool_column : (java.lang.Boolean) defaultValue(fields()[7]); - record.maybe_int_column = fieldSetFlags()[8] ? this.maybe_int_column : (java.lang.Integer) defaultValue(fields()[8]); - record.maybe_long_column = fieldSetFlags()[9] ? this.maybe_long_column : (java.lang.Long) defaultValue(fields()[9]); - record.maybe_float_column = fieldSetFlags()[10] ? this.maybe_float_column : (java.lang.Float) defaultValue(fields()[10]); - record.maybe_double_column = fieldSetFlags()[11] ? this.maybe_double_column : (java.lang.Double) defaultValue(fields()[11]); - record.maybe_binary_column = fieldSetFlags()[12] ? this.maybe_binary_column : (java.nio.ByteBuffer) defaultValue(fields()[12]); - record.maybe_string_column = fieldSetFlags()[13] ? this.maybe_string_column : (java.lang.String) defaultValue(fields()[13]); - record.strings_column = fieldSetFlags()[14] ? this.strings_column : (java.util.List) defaultValue(fields()[14]); - record.string_to_int_column = fieldSetFlags()[15] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[15]); - record.complex_column = fieldSetFlags()[16] ? this.complex_column : (java.util.Map>) defaultValue(fields()[16]); + record.strings_column = fieldSetFlags()[0] ? this.strings_column : (java.util.List) defaultValue(fields()[0]); + record.string_to_int_column = fieldSetFlags()[1] ? this.string_to_int_column : (java.util.Map) defaultValue(fields()[1]); + record.complex_column = fieldSetFlags()[2] ? this.complex_column : (java.util.Map>) defaultValue(fields()[2]); return record; } catch (Exception e) { throw new org.apache.avro.AvroRuntimeException(e); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 82d40e2b61a10..45db619567a22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.seqAsJavaListConverter +import scala.collection.JavaConverters.mapAsJavaMapConverter import org.apache.avro.Schema import org.apache.avro.generic.IndexedRecord @@ -32,48 +33,196 @@ import org.apache.spark.sql.execution.datasources.parquet.test.avro._ import org.apache.spark.sql.test.SharedSQLContext class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { - import ParquetCompatibilityTest._ - private def withWriter[T <: IndexedRecord] (path: String, schema: Schema) (f: AvroParquetWriter[T] => Unit): Unit = { + logInfo( + s"""Writing Avro records with the following Avro schema into Parquet file: + | + |${schema.toString(true)} + """.stripMargin) + val writer = new AvroParquetWriter[T](new Path(path), schema) try f(writer) finally writer.close() } - test("Read Parquet file generated by parquet-avro") { + test("required primitives") { withTempPath { dir => val path = dir.getCanonicalPath - 
withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => - (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + withWriter[AvroPrimitives](path, AvroPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write( + AvroPrimitives.newBuilder() + .setBoolColumn(i % 2 == 0) + .setIntColumn(i) + .setLongColumn(i.toLong * 10) + .setFloatColumn(i.toFloat + 0.1f) + .setDoubleColumn(i.toDouble + 0.2d) + .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setStringColumn(s"val_$i") + .build()) + } } - logInfo( - s"""Schema of the Parquet file written by parquet-avro: - |${readParquetSchema(path)} - """.stripMargin) + logParquetSchema(path) checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - Row( i % 2 == 0, i, i.toLong * 10, i.toFloat + 0.1f, i.toDouble + 0.2d, - s"val_$i".getBytes, - s"val_$i", + s"val_$i".getBytes("UTF-8"), + s"val_$i") + }) + } + } + + test("optional primitives") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroOptionalPrimitives](path, AvroOptionalPrimitives.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = if (i % 3 == 0) { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(null) + .setMaybeIntColumn(null) + .setMaybeLongColumn(null) + .setMaybeFloatColumn(null) + .setMaybeDoubleColumn(null) + .setMaybeBinaryColumn(null) + .setMaybeStringColumn(null) + .build() + } else { + AvroOptionalPrimitives.newBuilder() + .setMaybeBoolColumn(i % 2 == 0) + .setMaybeIntColumn(i) + .setMaybeLongColumn(i.toLong * 10) + .setMaybeFloatColumn(i.toFloat + 0.1f) + .setMaybeDoubleColumn(i.toDouble + 0.2d) + .setMaybeBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes("UTF-8"))) + .setMaybeStringColumn(s"val_$i") + .build() + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + if (i % 3 == 0) { + Row.apply(Seq.fill(7)(null): _*) + } else { + Row( + i % 2 == 0, + i, + i.toLong * 10, + i.toFloat + 0.1f, + i.toDouble + 0.2d, + s"val_$i".getBytes("UTF-8"), + s"val_$i") + } + }) + } + } + + test("non-nullable arrays") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroNonNullableArrays](path, AvroNonNullableArrays.getClassSchema) { writer => + (0 until 10).foreach { i => + val record = { + val builder = + AvroNonNullableArrays.newBuilder() + .setStringsColumn(Seq.tabulate(3)(i => s"val_$i").asJava) + + if (i % 3 == 0) { + builder.setMaybeIntsColumn(null).build() + } else { + builder.setMaybeIntsColumn(Seq.tabulate(3)(Int.box).asJava).build() + } + } + + writer.write(record) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( + Seq.tabulate(3)(i => s"val_$i"), + if (i % 3 == 0) null else Seq.tabulate(3)(identity)) + }) + } + } + + ignore("nullable arrays (parquet-avro 1.7.0 does not properly support this)") { + // TODO Complete this test case after upgrading to parquet-mr 1.8+ + } + + test("SPARK-10136 array of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroArrayOfArray](path, AvroArrayOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroArrayOfArray.newBuilder() + .setIntArraysColumn( + Seq.tabulate(3, 3)((i, j) => i * 3 + j: Integer).map(_.asJava).asJava) + .build()) + } + } - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i: Integer), - 
nullable(i.toLong: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i".getBytes), - nullable(s"val_$i"), + logParquetSchema(path) + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) + }) + } + } + + test("map of primitive array") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[AvroMapOfArray](path, AvroMapOfArray.getClassSchema) { writer => + (0 until 10).foreach { i => + writer.write(AvroMapOfArray.newBuilder() + .setStringToIntsColumn( + Seq.tabulate(3) { i => + i.toString -> Seq.tabulate(3)(j => i + j: Integer).asJava + }.toMap.asJava) + .build()) + } + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) + }) + } + } + + test("various complex types") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withWriter[ParquetAvroCompat](path, ParquetAvroCompat.getClassSchema) { writer => + (0 until 10).foreach(i => writer.write(makeParquetAvroCompat(i))) + } + + logParquetSchema(path) + + checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + Row( Seq.tabulate(3)(n => s"arr_${i + n}"), Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, Seq.tabulate(3) { n => @@ -86,47 +235,27 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared } def makeParquetAvroCompat(i: Int): ParquetAvroCompat = { - def nullable[T <: AnyRef] = makeNullable[T](i) _ - def makeComplexColumn(i: Int): JMap[String, JList[Nested]] = { - mapAsJavaMap(Seq.tabulate(3) { n => - (i + n).toString -> seqAsJavaList(Seq.tabulate(3) { m => + Seq.tabulate(3) { n => + (i + n).toString -> Seq.tabulate(3) { m => Nested .newBuilder() - .setNestedIntsColumn(seqAsJavaList(Seq.tabulate(3)(j => i + j + m))) + .setNestedIntsColumn(Seq.tabulate(3)(j => i + j + m: Integer).asJava) .setNestedStringColumn(s"val_${i + m}") .build() - }) - }.toMap) + }.asJava + }.toMap.asJava } ParquetAvroCompat .newBuilder() - .setBoolColumn(i % 2 == 0) - .setIntColumn(i) - .setLongColumn(i.toLong * 10) - .setFloatColumn(i.toFloat + 0.1f) - .setDoubleColumn(i.toDouble + 0.2d) - .setBinaryColumn(ByteBuffer.wrap(s"val_$i".getBytes)) - .setStringColumn(s"val_$i") - - .setMaybeBoolColumn(nullable(i % 2 == 0: java.lang.Boolean)) - .setMaybeIntColumn(nullable(i: Integer)) - .setMaybeLongColumn(nullable(i.toLong: java.lang.Long)) - .setMaybeFloatColumn(nullable(i.toFloat + 0.1f: java.lang.Float)) - .setMaybeDoubleColumn(nullable(i.toDouble + 0.2d: java.lang.Double)) - .setMaybeBinaryColumn(nullable(ByteBuffer.wrap(s"val_$i".getBytes))) - .setMaybeStringColumn(nullable(s"val_$i")) - - .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}")) - .setStringToIntColumn( - mapAsJavaMap(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap)) + .setStringsColumn(Seq.tabulate(3)(n => s"arr_${i + n}").asJava) + .setStringToIntColumn(Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap.asJava) .setComplexColumn(makeComplexColumn(i)) - .build() } - test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") { + test("SPARK-9407 Push down predicates involving Parquet ENUM columns") { import testImplicits._ withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index b3406729fcc5e..d85c564e3e8d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -43,6 +43,13 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) footers.head.getParquetMetadata.getFileMetaData.getSchema } + + protected def logParquetSchema(path: String): Unit = { + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |${readParquetSchema(path)} + """.stripMargin) + } } object ParquetCompatibilityTest { From 12de348332108f8c0c5bdad1d4cfac89b952b0f8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 20 Aug 2015 11:31:03 -0700 Subject: [PATCH 023/802] [SPARK-10126] [PROJECT INFRA] Fix typo in release-build.sh which broke snapshot publishing for Scala 2.11 The current `release-build.sh` has a typo which breaks snapshot publication for Scala 2.11. We should change the Scala version to 2.11 and clean before building a 2.11 snapshot. Author: Josh Rosen Closes #8325 from JoshRosen/fix-2.11-snapshots. --- dev/create-release/release-build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 399c73e7bf6bc..d0b3a54dde1dc 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -225,9 +225,9 @@ if [[ "$1" == "publish-snapshot" ]]; then $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ -Phive-thriftserver deploy - ./dev/change-scala-version.sh 2.10 + ./dev/change-scala-version.sh 2.11 $MVN -DzincPort=$ZINC_PORT -Dscala-2.11 --settings $tmp_settings \ - -DskipTests $PUBLISH_PROFILES deploy + -DskipTests $PUBLISH_PROFILES clean deploy # Clean-up Zinc nailgun process /usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill From 907df2fce00d2cbc9fae371344f05f800e0d2726 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 20 Aug 2015 13:51:54 -0700 Subject: [PATCH 024/802] [SQL] [MINOR] remove unnecessary class This class is identical to `org.apache.spark.sql.execution.datasources.jdbc. DefaultSource` and is not needed. Author: Wenchen Fan Closes #8334 from cloud-fan/minor. --- .../execution/datasources/DefaultSource.scala | 64 ------------------- 1 file changed, 64 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala deleted file mode 100644 index 6e4cc4de7f651..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DefaultSource.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution.datasources - -import java.util.Properties - -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry} -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} - - -class DefaultSource extends RelationProvider with DataSourceRegister { - - override def shortName(): String = "jdbc" - - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (driver != null) DriverRegistry.register(driver) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { - sys.error("Partitioning incompletely specified") - } - - val partitionInfo = if (partitionColumn == null) { - null - } else { - JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) - } - val parts = JDBCRelation.columnPartition(partitionInfo) - val properties = new Properties() // Additional properties that we will pass to getConnection - parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext) - } -} From 2a3d98aae285aba39786e9809f96de412a130f39 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Aug 2015 14:47:04 -0700 Subject: [PATCH 025/802] [SPARK-10138] [ML] move setters to MultilayerPerceptronClassifier and add Java test suite Otherwise, setters do not return self type. jkbradley avulanov Author: Xiangrui Meng Closes #8342 from mengxr/SPARK-10138. 
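For context, a minimal Java sketch of the fluent usage this change enables; it mirrors the Java test suite added below, and the layer sizes and other parameter values are arbitrary illustrations rather than recommended settings:

    import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;

    // Fluent configuration from Java: each setter now returns the classifier itself.
    MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
        .setLayers(new int[] {2, 5, 2})   // sizes of input, hidden, and output layers
        .setBlockSize(1)                  // block size for stacking input data
        .setSeed(11L)                     // seed for weight initialization
        .setMaxIter(100);                 // maximum number of iterations

    // With the setters previously declared in the private[ml] params trait they did not,
    // as the commit message notes, return the concrete self type, so this style of
    // chaining was not available to Java callers.
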
--- .../MultilayerPerceptronClassifier.scala | 54 +++++++------- ...vaMultilayerPerceptronClassifierSuite.java | 74 +++++++++++++++++++ 2 files changed, 101 insertions(+), 27 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index ccca4ecc004c3..1e5b0bc4453e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -42,9 +42,6 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams ParamValidators.arrayLengthGt(1) ) - /** @group setParam */ - def setLayers(value: Array[Int]): this.type = set(layers, value) - /** @group getParam */ final def getLayers: Array[Int] = $(layers) @@ -61,33 +58,9 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams "it is adjusted to the size of this data. Recommended size is between 10 and 1000", ParamValidators.gt(0)) - /** @group setParam */ - def setBlockSize(value: Int): this.type = set(blockSize, value) - /** @group getParam */ final def getBlockSize: Int = $(blockSize) - /** - * Set the maximum number of iterations. - * Default is 100. - * @group setParam - */ - def setMaxIter(value: Int): this.type = set(maxIter, value) - - /** - * Set the convergence tolerance of iterations. - * Smaller value will lead to higher accuracy with the cost of more iterations. - * Default is 1E-4. - * @group setParam - */ - def setTol(value: Double): this.type = set(tol, value) - - /** - * Set the seed for weights initialization. - * @group setParam - */ - def setSeed(value: Long): this.type = set(seed, value) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) } @@ -136,6 +109,33 @@ class MultilayerPerceptronClassifier(override val uid: String) def this() = this(Identifiable.randomUID("mlpc")) + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java new file mode 100644 index 0000000000000..ec6b4bf3c0f8c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +public class JavaMultilayerPerceptronClassifierSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + sqlContext = null; + } + + @Test + public void testMLPC() { + DataFrame dataFrame = sqlContext.createDataFrame( + jsc.parallelize(Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), + LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() + .setLayers(new int[] {2, 5, 2}) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100); + MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); + DataFrame result = model.transform(dataFrame); + Row[] predictionAndLabels = result.select("prediction", "label").collect(); + for (Row r: predictionAndLabels) { + Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); + } + } +} From 7cfc0750e14f2c1b3847e4720cc02150253525a9 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 20 Aug 2015 14:56:08 -0700 Subject: [PATCH 026/802] [SPARK-10108] Add since tags to mllib.feature Author: MechCoder Closes #8309 from MechCoder/tags_feature. 
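For context, here is a brief usage sketch of one of the classes annotated in this patch. The ChiSqSelector constructor, fit, and transform signatures match the diff below; the local SparkContext and the toy data are illustrative assumptions only.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.feature.ChiSqSelector
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    object ChiSqSelectorExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "ChiSqSelectorExample")
        // Discretized (categorical) features, as the fit() scaladoc recommends.
        val data = sc.parallelize(Seq(
          LabeledPoint(0.0, Vectors.dense(0.0, 1.0, 2.0)),
          LabeledPoint(1.0, Vectors.dense(1.0, 1.0, 0.0)),
          LabeledPoint(1.0, Vectors.dense(2.0, 0.0, 0.0))))
        // Keep the single feature with the highest chi-squared statistic.
        val model = new ChiSqSelector(1).fit(data)
        val reduced = data.map(lp => LabeledPoint(lp.label, model.transform(lp.features)))
        reduced.collect().foreach(println)
        sc.stop()
      }
    }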
--- .../spark/mllib/feature/ChiSqSelector.scala | 12 +++++++++--- .../mllib/feature/ElementwiseProduct.scala | 4 +++- .../spark/mllib/feature/HashingTF.scala | 11 ++++++++++- .../org/apache/spark/mllib/feature/IDF.scala | 8 +++++++- .../spark/mllib/feature/Normalizer.scala | 5 ++++- .../org/apache/spark/mllib/feature/PCA.scala | 9 ++++++++- .../spark/mllib/feature/StandardScaler.scala | 13 ++++++++++++- .../mllib/feature/VectorTransformer.scala | 6 +++++- .../apache/spark/mllib/feature/Word2Vec.scala | 19 ++++++++++++++++++- 9 files changed, 76 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 5f8c1dea237b4..fdd974d7a391e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.Statistics @@ -31,8 +31,10 @@ import org.apache.spark.rdd.RDD * * @param selectedFeatures list of indices to select (filter). Must be ordered asc */ +@Since("1.3.0") @Experimental -class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { +class ChiSqSelectorModel ( + @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -52,6 +54,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.3.0") override def transform(vector: Vector): Vector = { compress(vector, selectedFeatures) } @@ -107,8 +110,10 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) */ +@Since("1.3.0") @Experimental -class ChiSqSelector (val numTopFeatures: Int) extends Serializable { +class ChiSqSelector ( + @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. @@ -117,6 +122,7 @@ class ChiSqSelector (val numTopFeatures: Int) extends Serializable { * Real-valued features will be treated as categorical for each distinct value. * Apply feature discretizer before using this function. */ + @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val indices = Statistics.chiSqTest(data) .zipWithIndex.sortBy { case (res, _) => -res.statistic } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index d67fe6c3ee4f8..33e2d17bb472e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg._ /** @@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg._ * multiplier. 
* @param scalingVec The values used to scale the reference vector's individual components. */ +@Since("1.4.0") @Experimental class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { @@ -36,6 +37,7 @@ class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.4.0") override def transform(vector: Vector): Vector = { require(vector.size == scalingVec.size, s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c53475818395f..e47d524b61623 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -22,7 +22,7 @@ import java.lang.{Iterable => JavaIterable} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -34,19 +34,25 @@ import org.apache.spark.util.Utils * * @param numFeatures number of features (default: 2^20^) */ +@Since("1.1.0") @Experimental class HashingTF(val numFeatures: Int) extends Serializable { + /** + */ + @Since("1.1.0") def this() = this(1 << 20) /** * Returns the index of the input term. */ + @Since("1.1.0") def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures) /** * Transforms the input document into a sparse term frequency vector. */ + @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] document.foreach { term => @@ -59,6 +65,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document into a sparse term frequency vector (Java version). */ + @Since("1.1.0") def transform(document: JavaIterable[_]): Vector = { transform(document.asScala) } @@ -66,6 +73,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors. */ + @Since("1.1.0") def transform[D <: Iterable[_]](dataset: RDD[D]): RDD[Vector] = { dataset.map(this.transform) } @@ -73,6 +81,7 @@ class HashingTF(val numFeatures: Int) extends Serializable { /** * Transforms the input document to term frequency vectors (Java version). 
*/ + @Since("1.1.0") def transform[D <: JavaIterable[_]](dataset: JavaRDD[D]): JavaRDD[Vector] = { dataset.rdd.map(this.transform).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3fab7ea79befc..d5353ddd972e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * @param minDocFreq minimum of documents in which a term * should appear for filtering */ +@Since("1.1.0") @Experimental class IDF(val minDocFreq: Int) { @@ -48,6 +49,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: RDD[Vector]): IDFModel = { val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( minDocFreq = minDocFreq))( @@ -61,6 +63,7 @@ class IDF(val minDocFreq: Int) { * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ + @Since("1.1.0") def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } @@ -171,6 +174,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: RDD[Vector]): RDD[Vector] = { val bcIdf = dataset.context.broadcast(idf) dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(bcIdf.value, v))) @@ -182,6 +186,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param v a term frequency vector * @return a TF-IDF vector */ + @Since("1.3.0") def transform(v: Vector): Vector = IDFModel.transform(idf, v) /** @@ -189,6 +194,7 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable { * @param dataset a JavaRDD of term frequency vectors * @return a JavaRDD of TF-IDF vectors */ + @Since("1.1.0") def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { transform(dataset.rdd).toJavaRDD() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 32848e039eb81..0e070257d9fb2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** @@ -31,9 +31,11 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors * * @param p Normalization in L^p^ space, p = 2 by default. */ +@Since("1.1.0") @Experimental class Normalizer(p: Double) extends VectorTransformer { + @Since("1.1.0") def this() = this(2) require(p >= 1.0) @@ -44,6 +46,7 @@ class Normalizer(p: Double) extends VectorTransformer { * @param vector vector to be normalized. * @return normalized vector. 
If the norm of the input is zero, it will return the input vector. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { val norm = Vectors.norm(vector, p) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 2a66263d8b7d6..a48b7bba665d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.feature +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.distributed.RowMatrix @@ -27,6 +28,7 @@ import org.apache.spark.rdd.RDD * * @param k number of principal components */ +@Since("1.4.0") class PCA(val k: Int) { require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") @@ -35,6 +37,7 @@ class PCA(val k: Int) { * * @param sources source vectors */ + @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { require(k <= sources.first().size, s"source vector size is ${sources.first().size} must be greater than k=$k") @@ -58,7 +61,10 @@ class PCA(val k: Int) { new PCAModel(k, pc) } - /** Java-friendly version of [[fit()]] */ + /** + * Java-friendly version of [[fit()]] + */ + @Since("1.4.0") def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) } @@ -76,6 +82,7 @@ class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTr * Vector must be the same length as the source vectors given to [[PCA.fit()]]. * @return transformed vector. Vector will be of length k. */ + @Since("1.4.0") override def transform(vector: Vector): Vector = { vector match { case dv: DenseVector => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index c73b8f258060d..b95d5a899001e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -32,9 +32,11 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ +@Since("1.1.0") @Experimental class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { + @Since("1.1.0") def this() = this(false, true) if (!(withMean || withStd)) { @@ -47,6 +49,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param data The data used to compute the mean and variance to build the transformation model. 
* @return a StandardScalarModel */ + @Since("1.1.0") def fit(data: RDD[Vector]): StandardScalerModel = { // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( @@ -69,6 +72,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * @param withStd whether to scale the data to have unit standard deviation * @param withMean whether to center the data before scaling */ +@Since("1.1.0") @Experimental class StandardScalerModel ( val std: Vector, @@ -76,6 +80,9 @@ class StandardScalerModel ( var withStd: Boolean, var withMean: Boolean) extends VectorTransformer { + /** + */ + @Since("1.3.0") def this(std: Vector, mean: Vector) { this(std, mean, withStd = std != null, withMean = mean != null) require(this.withStd || this.withMean, @@ -86,8 +93,10 @@ class StandardScalerModel ( } } + @Since("1.3.0") def this(std: Vector) = this(std, null) + @Since("1.3.0") @DeveloperApi def setWithMean(withMean: Boolean): this.type = { require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") @@ -95,6 +104,7 @@ class StandardScalerModel ( this } + @Since("1.3.0") @DeveloperApi def setWithStd(withStd: Boolean): this.type = { require(!(withStd && this.std == null), @@ -115,6 +125,7 @@ class StandardScalerModel ( * @return Standardized vector. If the std of a column is zero, it will return default `0.0` * for the column with zero std. */ + @Since("1.1.0") override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 7358c1c84f79c..5778fd1d09254 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.feature -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD * :: DeveloperApi :: * Trait for transformation of a vector */ +@Since("1.1.0") @DeveloperApi trait VectorTransformer extends Serializable { @@ -35,6 +36,7 @@ trait VectorTransformer extends Serializable { * @param vector vector to be transformed. * @return transformed vector. */ + @Since("1.1.0") def transform(vector: Vector): Vector /** @@ -43,6 +45,7 @@ trait VectorTransformer extends Serializable { * @param data RDD[Vector] to be transformed. * @return transformed RDD[Vector]. */ + @Since("1.1.0") def transform(data: RDD[Vector]): RDD[Vector] = { // Later in #1498 , all RDD objects are sent via broadcasting instead of akka. // So it should be no longer necessary to explicitly broadcast `this` object. @@ -55,6 +58,7 @@ trait VectorTransformer extends Serializable { * @param data JavaRDD[Vector] to be transformed. * @return transformed JavaRDD[Vector]. 
*/ + @Since("1.1.0") def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { transform(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index cbbd2b0c8d060..e6f45ae4b01d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -32,7 +32,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -70,6 +70,7 @@ private case class VocabWord( * and * Distributed Representations of Words and Phrases and their Compositionality. */ +@Since("1.1.0") @Experimental class Word2Vec extends Serializable with Logging { @@ -83,6 +84,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets vector size (default: 100). */ + @Since("1.1.0") def setVectorSize(vectorSize: Int): this.type = { this.vectorSize = vectorSize this @@ -91,6 +93,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets initial learning rate (default: 0.025). */ + @Since("1.1.0") def setLearningRate(learningRate: Double): this.type = { this.learningRate = learningRate this @@ -99,6 +102,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets number of partitions (default: 1). Use a small number for accuracy. */ + @Since("1.1.0") def setNumPartitions(numPartitions: Int): this.type = { require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions") this.numPartitions = numPartitions @@ -109,6 +113,7 @@ class Word2Vec extends Serializable with Logging { * Sets number of iterations (default: 1), which should be smaller than or equal to number of * partitions. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.numIterations = numIterations this @@ -117,6 +122,7 @@ class Word2Vec extends Serializable with Logging { /** * Sets random seed (default: a random long integer). */ + @Since("1.1.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -126,6 +132,7 @@ class Word2Vec extends Serializable with Logging { * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). 
*/ + @Since("1.3.0") def setMinCount(minCount: Int): this.type = { this.minCount = minCount this @@ -263,6 +270,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset an RDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = { val words = dataset.flatMap(x => x) @@ -412,6 +420,7 @@ class Word2Vec extends Serializable with Logging { * @param dataset a JavaRDD of words * @return a Word2VecModel */ + @Since("1.1.0") def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { fit(dataset.rdd.map(_.asScala)) } @@ -454,6 +463,7 @@ class Word2VecModel private[mllib] ( wordVecNorms } + @Since("1.5.0") def this(model: Map[String, Array[Float]]) = { this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } @@ -469,6 +479,7 @@ class Word2VecModel private[mllib] ( override protected def formatVersion = "1.0" + @Since("1.4.0") def save(sc: SparkContext, path: String): Unit = { Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } @@ -478,6 +489,7 @@ class Word2VecModel private[mllib] ( * @param word a word * @return vector representation of word */ + @Since("1.1.0") def transform(word: String): Vector = { wordIndex.get(word) match { case Some(ind) => @@ -494,6 +506,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) findSynonyms(vector, num) @@ -505,6 +518,7 @@ class Word2VecModel private[mllib] ( * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ + @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k @@ -534,6 +548,7 @@ class Word2VecModel private[mllib] ( /** * Returns a map of words to their vector representations. */ + @Since("1.2.0") def getVectors: Map[String, Array[Float]] = { wordIndex.map { case (word, ind) => (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) @@ -541,6 +556,7 @@ class Word2VecModel private[mllib] ( } } +@Since("1.4.0") @Experimental object Word2VecModel extends Loader[Word2VecModel] { @@ -600,6 +616,7 @@ object Word2VecModel extends Loader[Word2VecModel] { } } + @Since("1.4.0") override def load(sc: SparkContext, path: String): Word2VecModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) From eaafe139f881d6105996373c9b11f2ccd91b5b3e Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 20 Aug 2015 15:01:31 -0700 Subject: [PATCH 027/802] [SPARK-9245] [MLLIB] LDA topic assignments For each (document, term) pair, return top topic. Note that instances of (doc, term) pairs within a document (a.k.a. "tokens") are exchangeable, so we should provide an estimate per document-term, rather than per token. CC: rotationsymmetry mengxr Author: Joseph K. Bradley Closes #8329 from jkbradley/lda-topic-assignments. 
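A rough sketch of how a caller might consume the new per-document assignments (assuming an already trained DistributedLDAModel): the two arrays returned for each document are zippable, pairing each term index with its most likely topic.

    import org.apache.spark.mllib.clustering.DistributedLDAModel

    // `model` is assumed to have been trained elsewhere; only topicAssignments comes from this patch.
    def printTopAssignments(model: DistributedLDAModel): Unit = {
      model.topicAssignments.take(5).foreach { case (docId, termIndices, topicIndices) =>
        // topicIndices(i) is the top topic for termIndices(i); terms absent from the doc are omitted.
        termIndices.zip(topicIndices).foreach { case (term, topic) =>
          println(s"doc $docId: term $term -> topic $topic")
        }
      }
    }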
--- .../spark/mllib/clustering/LDAModel.scala | 51 +++++++++++++++++-- .../spark/mllib/clustering/LDAOptimizer.scala | 2 +- .../spark/mllib/clustering/JavaLDASuite.java | 7 +++ .../spark/mllib/clustering/LDASuite.scala | 21 +++++++- 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b70e380c0393e..6bc68a4c18b99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -438,7 +438,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { Loader.checkSchema[Data](dataFrame.schema) val topics = dataFrame.collect() val vocabSize = topics(0).getAs[Vector](0).size - val k = topics.size + val k = topics.length val brzTopics = BDM.zeros[Double](vocabSize, k) topics.foreach { case Row(vec: Vector, ind: Int) => @@ -610,6 +610,50 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Return the top topic for each (doc, term) pair. I.e., for each document, what is the most + * likely topic generating each term? + * + * @return RDD of (doc ID, assignment of top topic index for each term), + * where the assignment is specified via a pair of zippable arrays + * (term indices, topic indices). Note that terms will be omitted if not present in + * the document. + */ + lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { + // For reference, compare the below code with the core part of EMLDAOptimizer.next(). + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration(0) + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Array[Int], Array[Int])] => Unit = + (edgeContext) => { + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions). + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) + // For this (doc j, term w), send top topic k to doc vertex. + val topTopic: Int = argmax(scaledTopicDistribution) + val term: Int = index2term(edgeContext.dstId) + edgeContext.sendToSrc((Array(term), Array(topTopic))) + } + val mergeMsg: ((Array[Int], Array[Int]), (Array[Int], Array[Int])) => (Array[Int], Array[Int]) = + (terms_topics0, terms_topics1) => { + (terms_topics0._1 ++ terms_topics1._1, terms_topics0._2 ++ terms_topics1._2) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val perDocAssignments = + graph.aggregateMessages[(Array[Int], Array[Int])](sendMsg, mergeMsg).filter(isDocumentVertex) + perDocAssignments.map { case (docID: Long, (terms: Array[Int], topics: Array[Int])) => + // TODO: Avoid zip, which is inefficient. 
+ val (sortedTerms, sortedTopics) = terms.zip(topics).sortBy(_._1).unzip + (docID, sortedTerms.toArray, sortedTopics.toArray) + } + } + + /** Java-friendly version of [[topicAssignments]] */ + lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { + topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -849,10 +893,9 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { - case (className, "1.0") if className == classNameV1_0 => { + case (className, "1.0") if className == classNameV1_0 => DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray, gammaShape) - } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 360241c8081ac..cb517f9689ade 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -167,7 +167,7 @@ final class EMLDAOptimizer extends LDAOptimizer { edgeContext.sendToDst((false, scaledTopicDistribution)) edgeContext.sendToSrc((false, scaledTopicDistribution)) } - // This is a hack to detect whether we could modify the values in-place. + // The Boolean is a hack to detect whether we could modify the values in-place. // TODO: Add zero/seqOp/combOp option to aggregateMessages. 
(SPARK-5438) val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = (m0, m1) => { diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index 6e91cde2eabb5..3fea359a3b46c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -134,6 +134,13 @@ public Boolean call(Tuple2 tuple2) { double[] topicWeights = topTopics._3(); assertEquals(3, topicIndices.length); assertEquals(3, topicWeights.length); + + // Check: topTopicAssignments + Tuple3 topicAssignment = model.javaTopicAssignments().first(); + Long docId2 = topicAssignment._1(); + int[] termIndices2 = topicAssignment._2(); + int[] topicIndices2 = topicAssignment._3(); + assertEquals(termIndices2.length, topicIndices2.length); } @Test diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 99e28499fd316..8a714f9b79e02 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -135,17 +135,34 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } // Top 3 documents per topic - model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } // All documents per topic val q = tinyCorpus.length - model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach { case (t1, t2) => assert(t1._1 === t2._1) assert(t1._2 === t2._2) } + + // Check: topTopicAssignments + // Make sure it assigns a topic to each term appearing in each doc. + val topTopicAssignments: Map[Long, (Array[Int], Array[Int])] = + model.topicAssignments.collect().map(x => x._1 -> (x._2, x._3)).toMap + assert(topTopicAssignments.keys.max < tinyCorpus.length) + tinyCorpus.foreach { case (docID: Long, doc: Vector) => + if (topTopicAssignments.contains(docID)) { + val (inds, vals) = topTopicAssignments(docID) + assert(inds.length === doc.numNonzeros) + // For "term" in actual doc, + // check that it has a topic assigned. + doc.foreachActive((term, wcnt) => assert(wcnt === 0 || inds.contains(term))) + } else { + assert(doc.numNonzeros === 0) + } + } } test("vertex indexing") { From afe9f03fd964d1e8604d02feee8d6970efbe6009 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Thu, 20 Aug 2015 15:10:13 -0700 Subject: [PATCH 028/802] [SPARK-9400] [SQL] codegen for StringLocate This is based on #7779 , thanks to tarekauel . Fix the conflict and nullability. Closes #7779 and #8274 . Author: Tarek Auel Author: Davies Liu Closes #8330 from davies/stringLocate. 
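As a small usage sketch (the DataFrame setup is an assumption for illustration; only the expression's semantics come from the diff): `locate` returns a 1-based position, 0 when the substring is absent, and NULL when a string argument is NULL, which is the nullability this patch pins down.

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions.{col, locate}

    object LocateDemo {
      // sqlContext is assumed to be supplied by the caller (e.g. a spark-shell session).
      def run(sqlContext: SQLContext): Unit = {
        import sqlContext.implicits._
        val df = Seq(Tuple1("foobarbar"), Tuple1(null: String)).toDF("s")
        df.select(
          locate("bar", col("s")),  // 1-based match position: 4 for "foobarbar"; NULL for the null row
          locate("xyz", col("s"))   // 0 when the substring is not found; NULL for the null row
        ).show()
      }
    }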
--- .../expressions/stringExpressions.scala | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3c23f2ecfb57c..b60d318534a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -409,13 +409,14 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil + override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -441,6 +442,31 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val substrGen = substr.gen(ctx) + val strGen = str.gen(ctx) + val startGen = start.gen(ctx) + s""" + int ${ev.primitive} = 0; + boolean ${ev.isNull} = false; + ${startGen.code} + if (!${startGen.isNull}) { + ${substrGen.code} + if (!${substrGen.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.indexOf(${substrGen.primitive}, + ${startGen.primitive}) + 1; + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "locate" } From cdd9a2bb10e20556003843a0f7aaa33acd55f6d2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 20 Aug 2015 20:01:13 -0700 Subject: [PATCH 029/802] [SPARK-10140] [DOC] add target fields to @Since so constructors parameters and public fields can be annotated. rxin MechCoder Author: Xiangrui Meng Closes #8344 from mengxr/SPARK-10140.2. --- core/src/main/scala/org/apache/spark/annotation/Since.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/core/src/main/scala/org/apache/spark/annotation/Since.scala index fa59393c22476..af483e361e339 100644 --- a/core/src/main/scala/org/apache/spark/annotation/Since.scala +++ b/core/src/main/scala/org/apache/spark/annotation/Since.scala @@ -18,6 +18,7 @@ package org.apache.spark.annotation import scala.annotation.StaticAnnotation +import scala.annotation.meta._ /** * A Scala annotation that specifies the Spark version when a definition was added. @@ -25,4 +26,5 @@ import scala.annotation.StaticAnnotation * hence works for overridden methods that inherit API documentation directly from parents. * The limitation is that it does not show up in the generated Java API documentation. 
*/ +@param @field @getter @setter @beanGetter @beanSetter private[spark] class Since(version: String) extends StaticAnnotation From dcfe0c5cde953b31c5bfeb6e41d1fc9b333241eb Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 20 Aug 2015 20:02:27 -0700 Subject: [PATCH 030/802] [SPARK-9846] [DOCS] User guide for Multilayer Perceptron Classifier Added user guide for multilayer perceptron classifier: - Simplified description of the multilayer perceptron classifier - Example code for Scala and Java Author: Alexander Ulanov Closes #8262 from avulanov/SPARK-9846-mlpc-docs. --- docs/ml-ann.md | 123 +++++++++++++++++++++++++++++++++++++++++++++++ docs/ml-guide.md | 1 + 2 files changed, 124 insertions(+) create mode 100644 docs/ml-ann.md diff --git a/docs/ml-ann.md b/docs/ml-ann.md new file mode 100644 index 0000000000000..d5ddd92af1e96 --- /dev/null +++ b/docs/ml-ann.md @@ -0,0 +1,123 @@ +--- +layout: global +title: Multilayer perceptron classifier - ML +displayTitle: ML - Multilayer perceptron classifier +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). +MLPC consists of multiple layers of nodes. +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs +by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +It can be written in matrix form for MLPC with `$K+1$` layers as follows: +`\[ +\mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) +\]` +Nodes in intermediate layers use sigmoid (logistic) function: +`\[ +\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}} +\]` +Nodes in the output layer use softmax function: +`\[ +\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}} +\]` +The number of nodes `$N$` in the output layer corresponds to the number of classes. + +MLPC employes backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. + +**Examples** + +
+ +
+ +{% highlight scala %} +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.Row + +// Load training data +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt").toDF() +// Split the data into train and test +val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) +val train = splits(0) +val test = splits(1) +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +val layers = Array[Int](4, 5, 4, 3) +// create the trainer and set its parameters +val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100) +// train the model +val model = trainer.fit(train) +// compute precision on the test set +val result = model.transform(test) +val predictionAndLabels = result.select("prediction", "label") +val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision") +println("Precision:" + evaluator.evaluate(predictionAndLabels)) +{% endhighlight %} + +
+ +
+ +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel; +import org.apache.spark.ml.classification.MultilayerPerceptronClassifier; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +// Load training data +String path = "data/mllib/sample_multiclass_classification_data.txt"; +JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); +DataFrame dataFrame = sqlContext.createDataFrame(data, LabeledPoint.class); +// Split the data into train and test +DataFrame[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); +DataFrame train = splits[0]; +DataFrame test = splits[1]; +// specify layers for the neural network: +// input layer of size 4 (features), two intermediate of size 5 and 4 and output of size 3 (classes) +int[] layers = new int[] {4, 5, 4, 3}; +// create the trainer and set its parameters +MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(128) + .setSeed(1234L) + .setMaxIter(100); +// train the model +MultilayerPerceptronClassificationModel model = trainer.fit(train); +// compute precision on the test set +DataFrame result = model.transform(test); +DataFrame predictionAndLabels = result.select("prediction", "label"); +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); +System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); +{% endhighlight %} +
+ +
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c64fff7c0315a..de8fead3529e4 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -179,6 +179,7 @@ There are now several algorithms in the Pipelines API which are not in the lower * [Decision Trees for Classification and Regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) # Code Examples From bb220f6570aa0b95598b30524224a3e82c1effbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 21 Aug 2015 01:43:49 -0700 Subject: [PATCH 031/802] [SPARK-10040] [SQL] Use batch insert for JDBC writing JIRA: https://issues.apache.org/jira/browse/SPARK-10040 We should use batch insert instead of single row in JDBC. Author: Liang-Chi Hsieh Closes #8273 from viirya/jdbc-insert-batch. --- .../execution/datasources/jdbc/JdbcUtils.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 2d0e736ee4766..26788b2a4fd69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -88,13 +88,15 @@ object JdbcUtils extends Logging { table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { + nullTypes: Array[Int], + batchSize: Int): Iterator[Byte] = { val conn = getConnection() var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. val stmt = insertStatement(conn, table, rddSchema) try { + var rowCount = 0 while (iterator.hasNext) { val row = iterator.next() val numFields = rddSchema.fields.length @@ -122,7 +124,15 @@ object JdbcUtils extends Logging { } i = i + 1 } - stmt.executeUpdate() + stmt.addBatch() + rowCount += 1 + if (rowCount % batchSize == 0) { + stmt.executeBatch() + rowCount = 0 + } + } + if (rowCount > 0) { + stmt.executeBatch() } } finally { stmt.close() @@ -211,8 +221,9 @@ object JdbcUtils extends Logging { val rddSchema = df.schema val driver: String = DriverRegistry.getDriverClassName(url) val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) } } From 708036c1de52d674ceff30ac465e1dcedeb8dde8 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 21 Aug 2015 08:41:36 -0500 Subject: [PATCH 032/802] [SPARK-9439] [YARN] External shuffle service robust to NM restarts using leveldb https://issues.apache.org/jira/browse/SPARK-9439 In general, Yarn apps should be robust to NodeManager restarts. However, if you run spark with the external shuffle service on, after a NM restart all shuffles fail, b/c the shuffle service has lost some state with info on each executor. (Note the shuffle data is perfectly fine on disk across a NM restart, the problem is we've lost the small bit of state that lets us *find* those files.) The solution proposed here is that the external shuffle service can write out its state to leveldb (backed by a local file) every time an executor is added. 
When running with yarn, that file is in the NM's local dir. Whenever the service is started, it looks for that file, and if it exists, it reads the file and re-registers all executors there. Nothing is changed in non-yarn modes with this patch. The service is not given a place to save the state to, so it operates the same as before. This should make it easy to update other cluster managers as well, by just supplying the right file & the equivalent of yarn's `initializeApplication` -- I'm not familiar enough with those modes to know how to do that. Author: Imran Rashid Closes #7943 from squito/leveldb_external_shuffle_service_NM_restart and squashes the following commits: 0d285d3 [Imran Rashid] review feedback 70951d6 [Imran Rashid] Merge branch 'master' into leveldb_external_shuffle_service_NM_restart 5c71c8c [Imran Rashid] save executor to db before registering; style 2499c8c [Imran Rashid] explicit dependency on jackson-annotations 795d28f [Imran Rashid] review feedback 81f80e2 [Imran Rashid] Merge branch 'master' into leveldb_external_shuffle_service_NM_restart 594d520 [Imran Rashid] use json to serialize application executor info 1a7980b [Imran Rashid] version 8267d2a [Imran Rashid] style e9f99e8 [Imran Rashid] cleanup the handling of bad dbs a little 9378ba3 [Imran Rashid] fail gracefully on corrupt leveldb files acedb62 [Imran Rashid] switch to writing out one record per executor 79922b7 [Imran Rashid] rely on yarn to call stopApplication; assorted cleanup 12b6a35 [Imran Rashid] save registered executors when apps are removed; add tests c878fbe [Imran Rashid] better explanation of shuffle service port handling 694934c [Imran Rashid] only open leveldb connection once per service d596410 [Imran Rashid] store executor data in leveldb 59800b7 [Imran Rashid] Files.move in case renaming is unsupported 32fe5ae [Imran Rashid] Merge branch 'master' into external_shuffle_service_NM_restart d7450f0 [Imran Rashid] style f729e2b [Imran Rashid] debugging 4492835 [Imran Rashid] lol, dont use a PrintWriter b/c of scalastyle checks 0a39b98 [Imran Rashid] Merge branch 'master' into external_shuffle_service_NM_restart 55f49fc [Imran Rashid] make sure the service doesnt die if the registered executor file is corrupt; add tests 245db19 [Imran Rashid] style 62586a6 [Imran Rashid] just serialize the whole executors map bdbbf0d [Imran Rashid] comments, remove some unnecessary changes 857331a [Imran Rashid] better tests & comments bb9d1e6 [Imran Rashid] formatting bdc4b32 [Imran Rashid] rename 86e0cb9 [Imran Rashid] for tests, shuffle service finds an open port 23994ff [Imran Rashid] style 7504de8 [Imran Rashid] style a36729c [Imran Rashid] cleanup efb6195 [Imran Rashid] proper unit test, and no longer leak if apps stop during NM restart dd93dc0 [Imran Rashid] test for shuffle service w/ NM restarts d596969 [Imran Rashid] cleanup imports 0e9d69b [Imran Rashid] better names 9eae119 [Imran Rashid] cleanup lots of duplication 1136f44 [Imran Rashid] test needs to have an actual shuffle 0b588bd [Imran Rashid] more fixes ... ad122ef [Imran Rashid] more fixes 5e5a7c3 [Imran Rashid] fix build c69f46b [Imran Rashid] maybe working version, needs tests & cleanup ... 
bb3ba49 [Imran Rashid] minor cleanup 36127d3 [Imran Rashid] wip b9d2ced [Imran Rashid] incomplete setup for external shuffle service tests --- .../spark/deploy/ExternalShuffleService.scala | 2 +- .../mesos/MesosExternalShuffleService.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 14 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- network/shuffle/pom.xml | 16 ++ .../shuffle/ExternalShuffleBlockHandler.java | 37 ++- .../shuffle/ExternalShuffleBlockResolver.java | 225 +++++++++++++++-- .../shuffle/protocol/ExecutorShuffleInfo.java | 8 +- .../ExternalShuffleBlockResolverSuite.java | 35 ++- .../shuffle/ExternalShuffleCleanupSuite.java | 9 +- .../ExternalShuffleIntegrationSuite.java | 2 +- .../shuffle/ExternalShuffleSecuritySuite.java | 5 +- .../network/yarn/YarnShuffleService.java | 62 ++++- pom.xml | 5 + yarn/pom.xml | 6 + .../deploy/yarn/BaseYarnClusterSuite.scala | 193 +++++++++++++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 173 +------------ .../yarn/YarnShuffleIntegrationSuite.scala | 109 ++++++++ .../network/shuffle/ShuffleTestAccessor.scala | 71 ++++++ .../yarn/YarnShuffleServiceSuite.scala | 233 ++++++++++++++++++ .../spark/network/yarn/YarnTestAccessor.scala | 37 +++ 21 files changed, 1031 insertions(+), 215 deletions(-) create mode 100644 yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala create mode 100644 yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala create mode 100644 yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 20a9faa1784b7..22ef701d833b2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -53,7 +53,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana /** Create a new shuffle block handler. Factored out for subclasses to override. */ protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = { - new ExternalShuffleBlockHandler(conf) + new ExternalShuffleBlockHandler(conf, null) } /** Starts the external shuffle service if the user has configured us to. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 061857476a8a0..12337a940a414 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -34,7 +34,7 @@ import org.apache.spark.network.util.TransportConf * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. 
*/ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf) - extends ExternalShuffleBlockHandler(transportConf) with Logging { + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { // Stores a map of driver socket addresses to app ids private val connectedApps = new mutable.HashMap[SocketAddress, String] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index eedb27942e841..fefaef0ab82c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -93,8 +93,17 @@ private[spark] class BlockManager( // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = - Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + private val externalShuffleServicePort = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // only if the yarn config has the port set to 0, we prefer the value in the spark config + conf.get("spark.shuffle.service.port").toInt + } else { + tmpPort + } + } // Check that we're not using external shuffle service with consolidated shuffle files. if (externalShuffleServiceEnabled @@ -191,6 +200,7 @@ private[spark] class BlockManager( executorId, blockTransferService.hostName, blockTransferService.port) shuffleServerId = if (externalShuffleServiceEnabled) { + logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index c38d70252add1..e846a72c888c6 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -36,7 +36,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) - rpcHandler = new ExternalShuffleBlockHandler(transportConf) + rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 532463e96fbb7..3d2edf9d94515 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -43,6 +43,22 @@ ${project.version} + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index db9dc4f17cee9..0df1dd621f6e9 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ 
b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -17,11 +17,12 @@ package org.apache.spark.network.shuffle; +import java.io.File; +import java.io.IOException; import java.util.List; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; -import org.apache.spark.network.util.TransportConf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -31,10 +32,10 @@ import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; -import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; +import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.TransportConf; + /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -46,11 +47,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler { private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); - private final ExternalShuffleBlockResolver blockManager; + @VisibleForTesting + final ExternalShuffleBlockResolver blockManager; private final OneForOneStreamManager streamManager; - public ExternalShuffleBlockHandler(TransportConf conf) { - this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf)); + public ExternalShuffleBlockHandler(TransportConf conf, File registeredExecutorFile) throws IOException { + this(new OneForOneStreamManager(), + new ExternalShuffleBlockResolver(conf, registeredExecutorFile)); } /** Enables mocking out the StreamManager and BlockManager. */ @@ -105,4 +108,22 @@ public StreamManager getStreamManager() { public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); } + + /** + * Register an (application, executor) with the given shuffle info. + * + * The "re-" is meant to highlight the intended use of this method -- when this service is + * restarted, this is used to restore the state of executors from before the restart. 
Normal + * registration will happen via a message handled in receive() + * + * @param appExecId + * @param executorInfo + */ + public void reregisterExecutor(AppExecId appExecId, ExecutorShuffleInfo executorInfo) { + blockManager.registerExecutor(appExecId.appId, appExecId.execId, executorInfo); + } + + public void close() { + blockManager.close(); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 022ed88a16480..79beec4429a99 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -17,19 +17,24 @@ package org.apache.spark.network.shuffle; -import java.io.DataInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.util.Iterator; -import java.util.Map; +import java.io.*; +import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Charsets; import com.google.common.base.Objects; import com.google.common.collect.Maps; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; +import org.iq80.leveldb.Options; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,25 +57,87 @@ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); + private static final ObjectMapper mapper = new ObjectMapper(); + /** + * This a common prefix to the key for each app registration we stick in leveldb, so they + * are easy to find, since leveldb lets you search based on prefix. + */ + private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; + private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + // Map containing all registered executors' metadata. - private final ConcurrentMap executors; + @VisibleForTesting + final ConcurrentMap executors; // Single-threaded Java executor used to perform expensive recursive directory deletion. private final Executor directoryCleaner; private final TransportConf conf; - public ExternalShuffleBlockResolver(TransportConf conf) { - this(conf, Executors.newSingleThreadExecutor( + @VisibleForTesting + final File registeredExecutorFile; + @VisibleForTesting + final DB db; + + public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) + throws IOException { + this(conf, registeredExecutorFile, Executors.newSingleThreadExecutor( // Add `spark` prefix because it will run in NM in Yarn mode. NettyUtils.createThreadFactory("spark-shuffle-directory-cleaner"))); } // Allows tests to have more control over when directories are cleaned up. 
@VisibleForTesting - ExternalShuffleBlockResolver(TransportConf conf, Executor directoryCleaner) { + ExternalShuffleBlockResolver( + TransportConf conf, + File registeredExecutorFile, + Executor directoryCleaner) throws IOException { this.conf = conf; - this.executors = Maps.newConcurrentMap(); + this.registeredExecutorFile = registeredExecutorFile; + if (registeredExecutorFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + DB tmpDb; + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + registeredExecutorFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. Creating new file, will not be able to " + + "recover state for existing applications", registeredExecutorFile, e); + if (registeredExecutorFile.isDirectory()) { + for (File f : registeredExecutorFile.listFiles()) { + f.delete(); + } + } + registeredExecutorFile.delete(); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb); + executors = reloadRegisteredExecutors(tmpDb); + db = tmpDb; + } else { + db = null; + executors = Maps.newConcurrentMap(); + } this.directoryCleaner = directoryCleaner; } @@ -81,6 +148,15 @@ public void registerExecutor( ExecutorShuffleInfo executorInfo) { AppExecId fullId = new AppExecId(appId, execId); logger.info("Registered executor {} with {}", fullId, executorInfo); + try { + if (db != null) { + byte[] key = dbAppExecKey(fullId); + byte[] value = mapper.writeValueAsString(executorInfo).getBytes(Charsets.UTF_8); + db.put(key, value); + } + } catch (Exception e) { + logger.error("Error saving registered executors", e); + } executors.put(fullId, executorInfo); } @@ -136,6 +212,13 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { // Only touch executors associated with the appId that was removed. if (appId.equals(fullId.appId)) { it.remove(); + if (db != null) { + try { + db.delete(dbAppExecKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } if (cleanupLocalDirs) { logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); @@ -220,12 +303,23 @@ static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) return new File(new File(localDir, String.format("%02x", subDirId)), filename); } + void close() { + if (db != null) { + try { + db.close(); + } catch (IOException e) { + logger.error("Exception closing leveldb with registered executors", e); + } + } + } + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ - private static class AppExecId { - final String appId; - final String execId; + public static class AppExecId { + public final String appId; + public final String execId; - private AppExecId(String appId, String execId) { + @JsonCreator + public AppExecId(@JsonProperty("appId") String appId, @JsonProperty("execId") String execId) { this.appId = appId; this.execId = execId; } @@ -252,4 +346,105 @@ public String toString() { .toString(); } } + + private static byte[] dbAppExecKey(AppExecId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(Charsets.UTF_8); + } + + private static AppExecId parseDbAppExecKey(String s) throws IOException { + if (!s.startsWith(APP_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_KEY_PREFIX); + } + String json = s.substring(APP_KEY_PREFIX.length() + 1); + AppExecId parsed = mapper.readValue(json, AppExecId.class); + return parsed; + } + + @VisibleForTesting + static ConcurrentMap reloadRegisteredExecutors(DB db) + throws IOException { + ConcurrentMap registeredExecutors = Maps.newConcurrentMap(); + if (db != null) { + DBIterator itr = db.iterator(); + itr.seek(APP_KEY_PREFIX.getBytes(Charsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), Charsets.UTF_8); + if (!key.startsWith(APP_KEY_PREFIX)) { + break; + } + AppExecId id = parseDbAppExecKey(key); + ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); + registeredExecutors.put(id, shuffleInfo); + } + } + return registeredExecutors; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. 
+ */ + private static void checkVersion(DB db) throws IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != CURRENT_VERSION.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + CURRENT_VERSION); + } + storeVersion(db); + } + } + + private static void storeVersion(DB db) throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); + } + + + public static class StoreVersion { + + final static byte[] KEY = "StoreVersion".getBytes(Charsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index cadc8e8369c6a..102d4efb8bf3b 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -19,6 +19,8 @@ import java.util.Arrays; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; @@ -34,7 +36,11 @@ public class ExecutorShuffleInfo implements Encodable { /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
*/ public final String shuffleManager; - public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + @JsonCreator + public ExecutorShuffleInfo( + @JsonProperty("localDirs") String[] localDirs, + @JsonProperty("subDirsPerLocalDir") int subDirsPerLocalDir, + @JsonProperty("shuffleManager") String shuffleManager) { this.localDirs = localDirs; this.subDirsPerLocalDir = subDirsPerLocalDir; this.shuffleManager = shuffleManager; diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d02f4f0fdb682..3c6cb367dea46 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -21,9 +21,12 @@ import java.io.InputStream; import java.io.InputStreamReader; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -59,8 +62,8 @@ public static void afterAll() { } @Test - public void testBadRequests() { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + public void testBadRequests() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); // Unregistered executor try { resolver.getBlockData("app0", "exec1", "shuffle_1_1_0"); @@ -91,7 +94,7 @@ public void testBadRequests() { @Test public void testSortShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); @@ -110,7 +113,7 @@ public void testSortShuffleBlocks() throws IOException { @Test public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf); + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); resolver.registerExecutor("app0", "exec0", dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); @@ -126,4 +129,28 @@ public void testHashShuffleBlocks() throws IOException { block1Stream.close(); assertEquals(hashBlock1, block1); } + + @Test + public void jsonSerializationOfExecutorRegistration() throws IOException { + ObjectMapper mapper = new ObjectMapper(); + AppExecId appId = new AppExecId("foo", "bar"); + String appIdJson = mapper.writeValueAsString(appId); + AppExecId parsedAppId = mapper.readValue(appIdJson, AppExecId.class); + assertEquals(parsedAppId, appId); + + ExecutorShuffleInfo shuffleInfo = + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + String shuffleJson = mapper.writeValueAsString(shuffleInfo); + ExecutorShuffleInfo parsedShuffleInfo = + mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); + assertEquals(parsedShuffleInfo, shuffleInfo); + + // Intentionally keep 
these hard-coded strings in here, to check backwards-compatability. + // its not legacy yet, but keeping this here in case anybody changes it + String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; + assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); + String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); + } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index d9d9c1bf2f17a..2f4f1d0df478b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -42,7 +42,7 @@ public void noCleanupAndCleanup() throws IOException { TestShuffleDataContext dataContext = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); resolver.applicationRemoved("app", false /* cleanup */); @@ -65,7 +65,8 @@ public void cleanupUsesExecutor() throws IOException { @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } }; - ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, noThreadExecutor); + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); manager.applicationRemoved("app", true); @@ -83,7 +84,7 @@ public void cleanupMultipleExecutors() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); @@ -99,7 +100,7 @@ public void cleanupOnlyRemovedApp() throws IOException { TestShuffleDataContext dataContext1 = createSomeData(); ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, sameThreadExecutor); + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); resolver.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); resolver.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 39aa49911d9cb..a3f9a38b1aeb9 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -92,7 +92,7 @@ public static void beforeAll() throws IOException { dataContext1.insertHashShuffleData(1, 0, exec1Blocks); conf = new TransportConf(new SystemPropertyConfigProvider()); - handler = new 
ExternalShuffleBlockHandler(conf); + handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index d4ec1956c1e29..aa99efda94948 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -43,8 +43,9 @@ public class ExternalShuffleSecuritySuite { TransportServer server; @Before - public void beforeEach() { - TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf)); + public void beforeEach() throws IOException { + TransportContext context = + new TransportContext(conf, new ExternalShuffleBlockHandler(conf, null)); TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, new TestSecretKeyHolder("my-app-id", "secret")); this.server = context.createServer(Arrays.asList(bootstrap)); diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 463f99ef3352d..11ea7f3fd3cfe 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,25 +17,21 @@ package org.apache.spark.network.yarn; +import java.io.File; import java.nio.ByteBuffer; import java.util.List; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; -import org.apache.hadoop.yarn.server.api.AuxiliaryService; -import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; -import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; -import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; -import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.apache.hadoop.yarn.server.api.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; -import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; @@ -79,11 +75,26 @@ public class YarnShuffleService extends AuxiliaryService { private TransportServer shuffleServer = null; // Handles registering executors and opening shuffle blocks - private ExternalShuffleBlockHandler blockHandler; + @VisibleForTesting + ExternalShuffleBlockHandler blockHandler; + + // Where to store & reload executor info for recovering state after an NM restart + @VisibleForTesting + File registeredExecutorFile; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static 
YarnShuffleService instance; public YarnShuffleService() { super("spark_shuffle"); logger.info("Initializing YARN shuffle service for Spark"); + instance = this; } /** @@ -100,11 +111,24 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. Even if + // an application was stopped while the NM was down, we expect yarn to call stopApplication() + // when it comes back + registeredExecutorFile = + findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs")); + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); - blockHandler = new ExternalShuffleBlockHandler(transportConf); + try { + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + } catch (Exception e) { + logger.error("Failed to initialize external shuffle service", e); + } List bootstraps = Lists.newArrayList(); if (authEnabled) { @@ -116,9 +140,13 @@ protected void serviceInit(Configuration conf) { SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); TransportContext transportContext = new TransportContext(transportConf, blockHandler); shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests its useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}.", port, authEnabledString); + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } @Override @@ -161,6 +189,16 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } + private File findRegisteredExecutorFile(String[] localDirs) { + for (String dir: localDirs) { + File f = new File(dir, "registeredExecutors.ldb"); + if (f.exists()) { + return f; + } + } + return new File(localDirs[0], "registeredExecutors.ldb"); + } + /** * Close the shuffle server to clean up any associated state. 
*/ @@ -170,6 +208,9 @@ protected void serviceStop() { if (shuffleServer != null) { shuffleServer.close(); } + if (blockHandler != null) { + blockHandler.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -180,5 +221,4 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } - } diff --git a/pom.xml b/pom.xml index ccfa1ea19b21e..d5945f2546d38 100644 --- a/pom.xml +++ b/pom.xml @@ -655,6 +655,11 @@ jackson-databind ${fasterxml.jackson.version} + + com.fasterxml.jackson.core + jackson-annotations + ${fasterxml.jackson.version} + diff --git a/yarn/pom.xml b/yarn/pom.xml index 15db54e4e7909..f6737695307a2 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -38,6 +38,12 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-network-yarn_${scala.binary.version} + ${project.version} + test + org.apache.spark spark-core_${scala.binary.version} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala new file mode 100644 index 0000000000000..128e996b71fe5 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.util.Properties +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster +import org.scalatest.{BeforeAndAfterAll, Matchers} + +import org.apache.spark._ +import org.apache.spark.util.Utils + +abstract class BaseYarnClusterSuite + extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { + + // log4j configuration for the YARN containers, so that their output is collected + // by YARN instead of trying to overwrite unit-tests.log. 
+ protected val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + protected var tempDir: File = _ + private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ + private var logConfDir: File = _ + + + def yarnConfig: YarnConfiguration + + override def beforeAll() { + super.beforeAll() + + tempDir = Utils.createTempDir() + logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, UTF_8) + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(yarnConfig) + yarnCluster.start() + + // There's a race in MiniYARNCluster in which start() may return before the RM has updated + // its address in the configuration. You can see this in the logs by noticing that when + // MiniYARNCluster prints the address, it still has port "0" assigned, although later the + // test works sometimes: + // + // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 + // + // That log message prints the contents of the RM_ADDRESS config variable. If you check it + // later on, it looks something like this: + // + // INFO YarnClusterSuite: RM address in configuration is blah:42631 + // + // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't + // done so in a timely manner (defined to be 10 seconds). + val config = yarnCluster.getConfig() + val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) + while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { + if (System.currentTimeMillis() > deadline) { + throw new IllegalStateException("Timed out waiting for RM to come up.") + } + logDebug("RM address still not set in configuration, waiting...") + TimeUnit.MILLISECONDS.sleep(100) + } + + logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) + } + + override def afterAll() { + yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") + super.afterAll() + } + + protected def runSpark( + clientMode: Boolean, + klass: String, + appArgs: Seq[String] = Nil, + sparkArgs: Seq[String] = Nil, + extraClassPath: Seq[String] = Nil, + extraJars: Seq[String] = Nil, + extraConf: Map[String, String] = Map()): Unit = { + val master = if (clientMode) "yarn-client" else "yarn-cluster" + val props = new Properties() + + props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) + + val childClasspath = logConfDir.getAbsolutePath() + + File.pathSeparator + + sys.props("java.class.path") + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator) + props.setProperty("spark.driver.extraClassPath", childClasspath) + props.setProperty("spark.executor.extraClassPath", childClasspath) + + // SPARK-4267: make sure java options are propagated correctly. 
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") + props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") + + yarnCluster.getConfig().foreach { e => + props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) + } + + sys.props.foreach { case (k, v) => + if (k.startsWith("spark.")) { + props.setProperty(k, v) + } + } + + extraConf.foreach { case (k, v) => props.setProperty(k, v) } + + val propsFile = File.createTempFile("spark", ".properties", tempDir) + val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) + props.store(writer, "Spark properties.") + writer.close() + + val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val mainArgs = + if (klass.endsWith(".py")) { + Seq(klass) + } else { + Seq("--class", klass, fakeSparkJar.getAbsolutePath()) + } + val argv = + Seq( + new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), + "--master", master, + "--num-executors", "1", + "--properties-file", propsFile.getAbsolutePath()) ++ + extraJarArgs ++ + sparkArgs ++ + mainArgs ++ + appArgs + + Utils.executeAndGetOutput(argv, + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + protected def checkResult(result: File): Unit = { + checkResult(result, "success") + } + + protected def checkResult(result: File, expected: String): Unit = { + val resultString = Files.toString(result, UTF_8) + resultString should be (expected) + } + + protected def mainClassName(klass: Class[_]): String = { + klass.getName().stripSuffix("$") + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index eb6e1fd370620..128350b648992 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -17,25 +17,20 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream, OutputStreamWriter} +import java.io.File import java.net.URL -import java.util.Properties -import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.collection.JavaConversions._ import com.google.common.base.Charsets.UTF_8 -import com.google.common.io.ByteStreams -import com.google.common.io.Files +import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, Matchers} +import org.scalatest.Matchers import org.apache.spark._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, - SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -43,17 +38,9 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. 
*/ -class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { - - // log4j configuration for the YARN containers, so that their output is collected - // by YARN instead of trying to overwrite unit-tests.log. - private val LOG4J_CONF = """ - |log4j.rootCategory=DEBUG, console - |log4j.appender.console=org.apache.log4j.ConsoleAppender - |log4j.appender.console.target=System.err - |log4j.appender.console.layout=org.apache.log4j.PatternLayout - |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n - """.stripMargin +class YarnClusterSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -82,65 +69,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | return 42 """.stripMargin - private var yarnCluster: MiniYARNCluster = _ - private var tempDir: File = _ - private var fakeSparkJar: File = _ - private var hadoopConfDir: File = _ - private var logConfDir: File = _ - - override def beforeAll() { - super.beforeAll() - - tempDir = Utils.createTempDir() - logConfDir = new File(tempDir, "log4j") - logConfDir.mkdir() - System.setProperty("SPARK_YARN_MODE", "true") - - val logConfFile = new File(logConfDir, "log4j.properties") - Files.write(LOG4J_CONF, logConfFile, UTF_8) - - yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(new YarnConfiguration()) - yarnCluster.start() - - // There's a race in MiniYARNCluster in which start() may return before the RM has updated - // its address in the configuration. You can see this in the logs by noticing that when - // MiniYARNCluster prints the address, it still has port "0" assigned, although later the - // test works sometimes: - // - // INFO MiniYARNCluster: MiniYARN ResourceManager address: blah:0 - // - // That log message prints the contents of the RM_ADDRESS config variable. If you check it - // later on, it looks something like this: - // - // INFO YarnClusterSuite: RM address in configuration is blah:42631 - // - // This hack loops for a bit waiting for the port to change, and fails the test if it hasn't - // done so in a timely manner (defined to be 10 seconds). 
- val config = yarnCluster.getConfig() - val deadline = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(10) - while (config.get(YarnConfiguration.RM_ADDRESS).split(":")(1) == "0") { - if (System.currentTimeMillis() > deadline) { - throw new IllegalStateException("Timed out waiting for RM to come up.") - } - logDebug("RM address still not set in configuration, waiting...") - TimeUnit.MILLISECONDS.sleep(100) - } - - logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") - - fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) - assert(hadoopConfDir.mkdir()) - File.createTempFile("token", ".txt", hadoopConfDir) - } - - override def afterAll() { - yarnCluster.stop() - System.clearProperty("SPARK_YARN_MODE") - super.afterAll() - } - test("run Spark in yarn-client mode") { testBasicYarnApp(true) } @@ -174,7 +102,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } private def testBasicYarnApp(clientMode: Boolean): Unit = { - var result = File.createTempFile("result", null, tempDir) + val result = File.createTempFile("result", null, tempDir) runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), appArgs = Seq(result.getAbsolutePath())) checkResult(result) @@ -224,89 +152,6 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(executorResult, "OVERRIDDEN") } - private def runSpark( - clientMode: Boolean, - klass: String, - appArgs: Seq[String] = Nil, - sparkArgs: Seq[String] = Nil, - extraClassPath: Seq[String] = Nil, - extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { - val master = if (clientMode) "yarn-client" else "yarn-cluster" - val props = new Properties() - - props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) - - // SPARK-4267: make sure java options are propagated correctly. - props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") - props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - - yarnCluster.getConfig().foreach { e => - props.setProperty("spark.hadoop." 
+ e.getKey(), e.getValue()) - } - - sys.props.foreach { case (k, v) => - if (k.startsWith("spark.")) { - props.setProperty(k, v) - } - } - - extraConf.foreach { case (k, v) => props.setProperty(k, v) } - - val propsFile = File.createTempFile("spark", ".properties", tempDir) - val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8) - props.store(writer, "Spark properties.") - writer.close() - - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil - val mainArgs = - if (klass.endsWith(".py")) { - Seq(klass) - } else { - Seq("--class", klass, fakeSparkJar.getAbsolutePath()) - } - val argv = - Seq( - new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(), - "--master", master, - "--num-executors", "1", - "--properties-file", propsFile.getAbsolutePath()) ++ - extraJarArgs ++ - sparkArgs ++ - mainArgs ++ - appArgs - - Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) - } - - /** - * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide - * any sort of error when the job process finishes successfully, but the job itself fails. So - * the tests enforce that something is written to a file after everything is ok to indicate - * that the job succeeded. - */ - private def checkResult(result: File): Unit = { - checkResult(result, "success") - } - - private def checkResult(result: File, expected: String): Unit = { - var resultString = Files.toString(result, UTF_8) - resultString should be (expected) - } - - private def mainClassName(klass: Class[_]): String = { - klass.getName().stripSuffix("$") - } - } private[spark] class SaveExecutorInfo extends SparkListener { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala new file mode 100644 index 0000000000000..5e8238822b90a --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -0,0 +1,109 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.scalatest.Matchers + +import org.apache.spark._ +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} + +/** + * Integration test for the external shuffle service with a yarn mini-cluster + */ +class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { + + override def yarnConfig: YarnConfiguration = { + val yarnConfig = new YarnConfiguration() + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + yarnConfig.set("spark.shuffle.service.port", "0") + yarnConfig + } + + test("external shuffle service") { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + + val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) + + logInfo("Shuffle service port = " + shuffleServicePort) + val result = File.createTempFile("result", null, tempDir) + runSpark( + false, + mainClassName(YarnExternalShuffleDriver.getClass), + appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), + extraConf = Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString + ) + ) + checkResult(result) + assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) + } +} + +private object YarnExternalShuffleDriver extends Logging with Matchers { + + val WAIT_TIMEOUT_MILLIS = 10000 + + def main(args: Array[String]): Unit = { + if (args.length != 2) { + // scalastyle:off println + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: ExternalShuffleDriver [result file] [registed exec file] + """.stripMargin) + // scalastyle:on println + System.exit(1) + } + + val sc = new SparkContext(new SparkConf() + .setAppName("External Shuffle Test")) + val conf = sc.getConf + val status = new File(args(0)) + val registeredExecFile = new File(args(1)) + logInfo("shuffle service executor file = " + registeredExecFile) + var result = "failure" + val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup") + try { + val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }. 
+ collect().toSet + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet) + result = "success" + // only one process can open a leveldb file at a time, so we copy the files + FileUtils.copyDirectory(registeredExecFile, execStateCopy) + assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty) + } finally { + sc.stop() + FileUtils.deleteDirectory(execStateCopy) + Files.write(result, status, UTF_8) + } + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala new file mode 100644 index 0000000000000..aa46ec5100f0e --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle + +import java.io.{IOException, File} +import java.util.concurrent.ConcurrentMap + +import com.google.common.annotations.VisibleForTesting +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.fusesource.leveldbjni.JniDBFactory +import org.iq80.leveldb.{DB, Options} + +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +/** + * just a cheat to get package-visible members in tests + */ +object ShuffleTestAccessor { + + def getBlockResolver(handler: ExternalShuffleBlockHandler): ExternalShuffleBlockResolver = { + handler.blockManager + } + + def getExecutorInfo( + appId: ApplicationId, + execId: String, + resolver: ExternalShuffleBlockResolver + ): Option[ExecutorShuffleInfo] = { + val id = new AppExecId(appId.toString, execId) + Option(resolver.executors.get(id)) + } + + def registeredExecutorFile(resolver: ExternalShuffleBlockResolver): File = { + resolver.registeredExecutorFile + } + + def shuffleServiceLevelDB(resolver: ExternalShuffleBlockResolver): DB = { + resolver.db + } + + def reloadRegisteredExecutors( + file: File): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + val options: Options = new Options + options.createIfMissing(true) + val factory = new JniDBFactory + val db = factory.open(file, options) + val result = ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + db.close() + result + } + + def reloadRegisteredExecutors( + db: DB): ConcurrentMap[ExternalShuffleBlockResolver.AppExecId, ExecutorShuffleInfo] = { + ExternalShuffleBlockResolver.reloadRegisteredExecutors(db) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala new file 
mode 100644 index 0000000000000..2f22cbdbeac37 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.yarn + +import java.io.{DataOutputStream, File, FileOutputStream} + +import scala.annotation.tailrec + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.yarn.api.records.ApplicationId +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.api.{ApplicationInitializationContext, ApplicationTerminationContext} +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.shuffle.ShuffleTestAccessor +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo + +class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration + + override def beforeEach(): Unit = { + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") + yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), + classOf[YarnShuffleService].getCanonicalName) + + yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => + val d = new File(dir) + if (d.exists()) { + FileUtils.deleteDirectory(d) + } + FileUtils.forceMkdir(d) + logInfo(s"creating yarn.nodemanager.local-dirs: $d") + } + } + + var s1: YarnShuffleService = null + var s2: YarnShuffleService = null + var s3: YarnShuffleService = null + + override def afterEach(): Unit = { + if (s1 != null) { + s1.stop() + s1 = null + } + if (s2 != null) { + s2.stop() + s2 = null + } + if (s3 != null) { + s3.stop() + s3 = null + } + } + + test("executor state kept across NM restart") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) 
+ blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", blockResolver) should + be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", blockResolver) should + be (Some(shuffleInfo2)) + + if (!execStateFile.exists()) { + @tailrec def findExistingParent(file: File): File = { + if (file == null) file + else if (file.exists()) file + else findExistingParent(file.getParentFile()) + } + val existingParent = findExistingParent(execStateFile) + assert(false, s"$execStateFile does not exist -- closest existing parent is $existingParent") + } + assert(execStateFile.exists(), s"$execStateFile did not exist") + + // now we pretend the shuffle service goes down, and comes back up + s1.stop() + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // now we reinitialize only one of the apps, and expect yarn to tell us that app2 was stopped + // during the restart + s2.initializeApplication(app1Data) + s2.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver2) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (None) + + // Act like the NM restarts one more time + s2.stop() + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + // app1 is still running + s3.initializeApplication(app1Data) + ShuffleTestAccessor.getExecutorInfo(app1Id, "exec-1", resolver3) should be (Some(shuffleInfo1)) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (None) + s3.stop() + } + + test("removed applications should not be in registered executor file") { + s1 = new YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s1.initializeApplication(app2Data) + + val execStateFile = s1.registeredExecutorFile + execStateFile should not be (null) + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + blockResolver.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + + val db = ShuffleTestAccessor.shuffleServiceLevelDB(blockResolver) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + + s1.stopApplication(new ApplicationTerminationContext(app1Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) should not be empty + s1.stopApplication(new ApplicationTerminationContext(app2Id)) + ShuffleTestAccessor.reloadRegisteredExecutors(db) shouldBe empty + } + + test("shuffle service should be robust to corrupt registered executor file") { + s1 = new 
YarnShuffleService + s1.init(yarnConfig) + val app1Id = ApplicationId.newInstance(0, 1) + val app1Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app1Id, null) + s1.initializeApplication(app1Data) + + val execStateFile = s1.registeredExecutorFile + val shuffleInfo1 = new ExecutorShuffleInfo(Array("/foo", "/bar"), 3, "sort") + + val blockHandler = s1.blockHandler + val blockResolver = ShuffleTestAccessor.getBlockResolver(blockHandler) + ShuffleTestAccessor.registeredExecutorFile(blockResolver) should be (execStateFile) + + blockResolver.registerExecutor(app1Id.toString, "exec-1", shuffleInfo1) + + // now we pretend the shuffle service goes down, and comes back up. But we'll also + // make a corrupt registeredExecutor File + s1.stop() + + execStateFile.listFiles().foreach{_.delete()} + + val out = new DataOutputStream(new FileOutputStream(execStateFile + "/CURRENT")) + out.writeInt(42) + out.close() + + s2 = new YarnShuffleService + s2.init(yarnConfig) + s2.registeredExecutorFile should be (execStateFile) + + val handler2 = s2.blockHandler + val resolver2 = ShuffleTestAccessor.getBlockResolver(handler2) + + // we re-initialize app1, but since the file was corrupt there is nothing we can do about it ... + s2.initializeApplication(app1Data) + // however, when we initialize a totally new app2, everything is still happy + val app2Id = ApplicationId.newInstance(0, 2) + val app2Data: ApplicationInitializationContext = + new ApplicationInitializationContext("user", app2Id, null) + s2.initializeApplication(app2Data) + val shuffleInfo2 = new ExecutorShuffleInfo(Array("/bippy"), 5, "hash") + resolver2.registerExecutor(app2Id.toString, "exec-2", shuffleInfo2) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver2) should be (Some(shuffleInfo2)) + s2.stop() + + // another stop & restart should be fine though (eg., we recover from previous corruption) + s3 = new YarnShuffleService + s3.init(yarnConfig) + s3.registeredExecutorFile should be (execStateFile) + val handler3 = s3.blockHandler + val resolver3 = ShuffleTestAccessor.getBlockResolver(handler3) + + s3.initializeApplication(app2Data) + ShuffleTestAccessor.getExecutorInfo(app2Id, "exec-2", resolver3) should be (Some(shuffleInfo2)) + s3.stop() + + } + +} diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala new file mode 100644 index 0000000000000..db322cd18e150 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnTestAccessor.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.yarn + +import java.io.File + +/** + * just a cheat to get package-visible members in tests + */ +object YarnTestAccessor { + def getShuffleServicePort: Int = { + YarnShuffleService.boundPort + } + + def getShuffleServiceInstance: YarnShuffleService = { + YarnShuffleService.instance + } + + def getRegisteredExecutorFile(service: YarnShuffleService): File = { + service.registeredExecutorFile + } + +} From 3c462f5d87a9654c5a68fd658a40f5062029fd9a Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 21 Aug 2015 12:21:51 -0700 Subject: [PATCH 033/802] [SPARK-10130] [SQL] type coercion for IF should have children resolved first Type coercion for IF should have children resolved first, or we could meet unresolved exception. Author: Daoyuan Wang Closes #8331 from adrian-wang/spark10130. --- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 1 + .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index f2f2ba2f96552..2cb067f4aac91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -639,6 +639,7 @@ object HiveTypeCoercion { */ object IfCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index da50aec17c89e..dcb4e83710982 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1679,4 +1679,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10130 type coercion for IF should have children resolved first") { + val df = Seq((1, 1), (-1, 1)).toDF("key", "value") + df.registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } } From d89cc38b33815e7b99fb3389b5038a543527065d Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 21 Aug 2015 13:10:11 -0700 Subject: [PATCH 034/802] [SPARK-10122] [PYSPARK] [STREAMING] Fix getOffsetRanges bug in PySpark-Streaming transform function Details of the bug and explanations can be seen in [SPARK-10122](https://issues.apache.org/jira/browse/SPARK-10122). tdas , please help to review. 
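To make the SPARK-10130 change above concrete: the IfCoercion rule now returns any expression whose children are still unresolved untouched, so it never calls dataType on an unresolved child. Below is a minimal, self-contained Scala sketch of that guard pattern; Expr, Literal, UnresolvedAttr, Cast, If and ifCoercion are simplified stand-ins for illustration, not the real Catalyst API.

object IfCoercionSketch {

  sealed trait Expr {
    def resolved: Boolean
    def children: Seq[Expr]
    def childrenResolved: Boolean = children.forall(_.resolved)
    def dataType: String
  }

  case class Literal(value: Any, dataType: String) extends Expr {
    val resolved = true
    val children = Nil
  }

  case class UnresolvedAttr(name: String) extends Expr {
    val resolved = false
    val children = Nil
    def dataType: String = sys.error(s"attribute $name is not resolved yet")
  }

  case class Cast(child: Expr, dataType: String) extends Expr {
    val resolved = child.resolved
    val children = Seq(child)
  }

  case class If(pred: Expr, left: Expr, right: Expr) extends Expr {
    val children = Seq(pred, left, right)
    val resolved = childrenResolved
    def dataType: String = left.dataType
  }

  // Toy version of the coercion rule: the first case is the SPARK-10130 fix --
  // bail out until the analyzer has resolved all children; only then compare
  // branch types and insert casts (here the "widest type" is simply "string").
  def ifCoercion(e: Expr): Expr = e match {
    case e if !e.childrenResolved => e
    case If(p, l, r) if l.dataType != r.dataType =>
      If(p, Cast(l, "string"), Cast(r, "string"))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val expr = If(Literal(true, "boolean"), UnresolvedAttr("a"), Literal(0, "int"))
    // With the guard the unresolved If comes back unchanged; without it,
    // l.dataType would throw before the analyzer had resolved attribute "a".
    println(ifCoercion(expr))
  }
}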
Author: jerryshao Closes #8347 from jerryshao/SPARK-10122 and squashes the following commits: 4039b16 [jerryshao] Fix getOffsetRanges in transform() bug --- python/pyspark/streaming/dstream.py | 5 ++++- python/pyspark/streaming/tests.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 8dcb9645cdc6b..698336cfce18d 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -610,7 +610,10 @@ def __init__(self, prev, func): self.is_checkpointed = False self._jdstream_val = None - if (isinstance(prev, TransformedDStream) and + # Using type() to avoid folding the functions and compacting the DStreams which is not + # not strictly a object of TransformedDStream. + # Changed here is to avoid bug in KafkaTransformedDStream when calling offsetRanges(). + if (type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func self.func = lambda t, rdd: func(t, prev_func(t, rdd)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 6108c845c1efe..214d5be439003 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -850,7 +850,9 @@ def transformWithOffsetRanges(rdd): offsetRanges.append(o) return rdd - stream.transform(transformWithOffsetRanges).foreachRDD(lambda rdd: rdd.count()) + # Test whether it is ok mixing KafkaTransformedDStream and TransformedDStream together, + # only the TransformedDstreams can be folded together. + stream.transform(transformWithOffsetRanges).map(lambda kv: kv[1]).count().pprint() self.ssc.start() self.wait_for(offsetRanges, 1) From f5b028ed2f1ad6de43c8b50ebf480e1b6c047035 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 21 Aug 2015 14:19:24 -0700 Subject: [PATCH 035/802] [SPARK-9864] [DOC] [MLlib] [SQL] Replace since in scaladoc to Since annotation Author: MechCoder Closes #8352 from MechCoder/since. 
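The SPARK-9864 diff that follows is almost entirely mechanical: each "@since x.y.z" scaladoc tag is removed and the member is annotated with @Since("x.y.z") from org.apache.spark.annotation instead, so the doc comment keeps only the description while the version is carried as structured metadata. A small before/after sketch of the pattern, using a hypothetical ExampleClusterer.setK rather than any file from the diff; the package name is made up and placed inside org.apache.spark only so the annotation is visible exactly as in the hunks below.

package org.apache.spark.mllib.example

import org.apache.spark.annotation.Since  // same import the hunks below add

// Hypothetical class, for illustration only; the real edits touch the files
// listed in the diffstat that follows.
class ExampleClusterer {

  private var k: Int = 2

  // Before: the version lived only inside the scaladoc text.
  //   /**
  //    * Set the number of clusters to create. Default: 2.
  //    * @since 0.8.0
  //    */
  //   def setK(k: Int): this.type = { this.k = k; this }

  // After: the tag is dropped and the version becomes an annotation.
  /**
   * Set the number of clusters to create. Default: 2.
   */
  @Since("0.8.0")
  def setK(k: Int): this.type = {
    this.k = k
    this
  }
}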
--- .../classification/ClassificationModel.scala | 8 +- .../classification/LogisticRegression.scala | 30 ++--- .../mllib/classification/NaiveBayes.scala | 7 +- .../spark/mllib/classification/SVM.scala | 28 ++--- .../mllib/clustering/GaussianMixture.scala | 28 ++--- .../clustering/GaussianMixtureModel.scala | 28 ++--- .../spark/mllib/clustering/KMeans.scala | 50 ++++----- .../spark/mllib/clustering/KMeansModel.scala | 27 ++--- .../apache/spark/mllib/clustering/LDA.scala | 56 ++++----- .../spark/mllib/clustering/LDAModel.scala | 69 +++++------- .../spark/mllib/clustering/LDAOptimizer.scala | 24 ++-- .../clustering/PowerIterationClustering.scala | 38 +++---- .../mllib/clustering/StreamingKMeans.scala | 35 +++--- .../BinaryClassificationMetrics.scala | 26 ++--- .../mllib/evaluation/MulticlassMetrics.scala | 20 ++-- .../mllib/evaluation/MultilabelMetrics.scala | 9 +- .../mllib/evaluation/RankingMetrics.scala | 10 +- .../mllib/evaluation/RegressionMetrics.scala | 14 +-- .../spark/mllib/fpm/AssociationRules.scala | 20 ++-- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 22 ++-- .../apache/spark/mllib/linalg/Matrices.scala | 106 ++++++++---------- .../linalg/SingularValueDecomposition.scala | 4 +- .../apache/spark/mllib/linalg/Vectors.scala | 90 +++++---------- .../linalg/distributed/BlockMatrix.scala | 88 ++++++--------- .../linalg/distributed/CoordinateMatrix.scala | 40 +++---- .../distributed/DistributedMatrix.scala | 4 +- .../linalg/distributed/IndexedRowMatrix.scala | 38 +++---- .../mllib/linalg/distributed/RowMatrix.scala | 39 +++---- .../spark/mllib/recommendation/ALS.scala | 22 ++-- .../MatrixFactorizationModel.scala | 28 +++-- .../GeneralizedLinearAlgorithm.scala | 24 ++-- .../mllib/regression/IsotonicRegression.scala | 22 ++-- .../spark/mllib/regression/LabeledPoint.scala | 7 +- .../apache/spark/mllib/regression/Lasso.scala | 25 ++--- .../mllib/regression/LinearRegression.scala | 25 ++--- .../mllib/regression/RegressionModel.scala | 12 +- .../mllib/regression/RidgeRegression.scala | 25 ++--- .../regression/StreamingLinearAlgorithm.scala | 18 +-- .../spark/mllib/stat/KernelDensity.scala | 12 +- .../stat/MultivariateOnlineSummarizer.scala | 24 ++-- .../stat/MultivariateStatisticalSummary.scala | 19 ++-- .../apache/spark/mllib/stat/Statistics.scala | 30 ++--- .../distribution/MultivariateGaussian.scala | 8 +- .../spark/mllib/tree/DecisionTree.scala | 28 +++-- .../mllib/tree/GradientBoostedTrees.scala | 20 ++-- .../spark/mllib/tree/RandomForest.scala | 20 ++-- .../spark/mllib/tree/configuration/Algo.scala | 4 +- .../tree/configuration/BoostingStrategy.scala | 12 +- .../tree/configuration/FeatureType.scala | 4 +- .../tree/configuration/QuantileStrategy.scala | 4 +- .../mllib/tree/configuration/Strategy.scala | 24 ++-- .../spark/mllib/tree/impurity/Entropy.scala | 10 +- .../spark/mllib/tree/impurity/Gini.scala | 10 +- .../spark/mllib/tree/impurity/Impurity.scala | 8 +- .../spark/mllib/tree/impurity/Variance.scala | 10 +- .../spark/mllib/tree/loss/AbsoluteError.scala | 6 +- .../spark/mllib/tree/loss/LogLoss.scala | 6 +- .../apache/spark/mllib/tree/loss/Loss.scala | 8 +- .../apache/spark/mllib/tree/loss/Losses.scala | 10 +- .../spark/mllib/tree/loss/SquaredError.scala | 6 +- .../mllib/tree/model/DecisionTreeModel.scala | 22 ++-- .../tree/model/InformationGainStats.scala | 4 +- .../apache/spark/mllib/tree/model/Node.scala | 8 +- .../spark/mllib/tree/model/Predict.scala | 4 +- .../apache/spark/mllib/tree/model/Split.scala | 4 +- .../mllib/tree/model/treeEnsembleModels.scala | 26 ++--- 
.../org/apache/spark/mllib/tree/package.scala | 1 - .../org/apache/spark/mllib/util/MLUtils.scala | 36 +++--- 68 files changed, 692 insertions(+), 862 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index ba73024e3c04d..a29b425a71fd6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.classification import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -36,8 +36,8 @@ trait ClassificationModel extends Serializable { * * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -45,16 +45,16 @@ trait ClassificationModel extends Serializable { * * @param testData array representing a single data point * @return predicted category from the trained model - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: Vector): Double /** * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction - * @since 0.8.0 */ + @Since("0.8.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 268642ac6a2f6..e03e662227d14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} @@ -85,8 +85,8 @@ class LogisticRegressionModel ( * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. - * @since 1.0.0 */ + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -97,8 +97,8 @@ class LogisticRegressionModel ( * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. * It is only used for binary classification. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def getThreshold: Option[Double] = threshold @@ -106,8 +106,8 @@ class LogisticRegressionModel ( * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. 
* It is only used for binary classification. - * @since 1.0.0 */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None @@ -158,9 +158,7 @@ class LogisticRegressionModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures, numClasses, weights, intercept, threshold) @@ -168,9 +166,7 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } @@ -178,9 +174,7 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Loader[LogisticRegressionModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -261,8 +255,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -284,8 +278,8 @@ object LogisticRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -306,8 +300,8 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -324,8 +318,8 @@ object LogisticRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int): LogisticRegressionModel = { @@ -361,8 +355,8 @@ class LogisticRegressionWithLBFGS * Set the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. 
- * @since 1.3.0 */ + @Since("1.3.0") @Experimental def setNumClasses(numClasses: Int): this.type = { require(numClasses > 1) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 2df91c09421e9..dab369207cc9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -25,6 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} @@ -444,8 +445,8 @@ object NaiveBayes { * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint]): NaiveBayesModel = { new NaiveBayes().run(input) } @@ -460,8 +461,8 @@ object NaiveBayes { * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { new NaiveBayes(lambda, Multinomial).run(input) } @@ -483,8 +484,8 @@ object NaiveBayes { * * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli - * @since 0.9.0 */ + @Since("0.9.0") def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 5b54feeb10467..5f87269863572 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ @@ -46,8 +46,8 @@ class SVMModel ( * Sets the threshold that separates positive predictions from negative predictions. An example * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -57,16 +57,16 @@ class SVMModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. - * @since 1.3.0 */ + @Since("1.3.0") @Experimental def getThreshold: Option[Double] = threshold /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. 
- * @since 1.0.0 */ + @Since("1.0.0") @Experimental def clearThreshold(): this.type = { threshold = None @@ -84,9 +84,7 @@ class SVMModel ( } } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) @@ -94,9 +92,7 @@ class SVMModel ( override protected def formatVersion: String = "1.0" - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } @@ -104,9 +100,7 @@ class SVMModel ( object SVMModel extends Loader[SVMModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): SVMModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -185,8 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -209,8 +203,8 @@ object SVMWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -231,8 +225,8 @@ object SVMWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -250,8 +244,8 @@ object SVMWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. - * @since 0.8.0 */ + @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 0.01, 1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index bc27b1fe7390b..fcc9dfecac54f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -62,8 +62,8 @@ class GaussianMixture private ( /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. 
- * @since 1.3.0 */ + @Since("1.3.0") def this() = this(2, 0.01, 100, Utils.random.nextLong()) // number of samples per cluster to use when initializing Gaussians @@ -77,8 +77,8 @@ class GaussianMixture private ( * Set the initial GMM starting point, bypassing the random initialization. * You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException - * @since 1.3.0 */ + @Since("1.3.0") def setInitialModel(model: GaussianMixtureModel): this.type = { if (model.k == k) { initialModel = Some(model) @@ -90,14 +90,14 @@ class GaussianMixture private ( /** * Return the user supplied initial GMM, if supplied - * @since 1.3.0 */ + @Since("1.3.0") def getInitialModel: Option[GaussianMixtureModel] = initialModel /** * Set the number of Gaussians in the mixture model. Default: 2 - * @since 1.3.0 */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this @@ -105,14 +105,14 @@ class GaussianMixture private ( /** * Return the number of Gaussians in the mixture model - * @since 1.3.0 */ + @Since("1.3.0") def getK: Int = k /** * Set the maximum number of iterations to run. Default: 100 - * @since 1.3.0 */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -120,15 +120,15 @@ class GaussianMixture private ( /** * Return the maximum number of iterations to run - * @since 1.3.0 */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Set the largest change in log-likelihood at which convergence is * considered to have occurred. - * @since 1.3.0 */ + @Since("1.3.0") def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this @@ -137,14 +137,14 @@ class GaussianMixture private ( /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. 
- * @since 1.3.0 */ + @Since("1.3.0") def getConvergenceTol: Double = convergenceTol /** * Set the random seed - * @since 1.3.0 */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -152,14 +152,14 @@ class GaussianMixture private ( /** * Return the random seed - * @since 1.3.0 */ + @Since("1.3.0") def getSeed: Long = seed /** * Perform expectation maximization - * @since 1.3.0 */ + @Since("1.3.0") def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext @@ -235,8 +235,8 @@ class GaussianMixture private ( /** * Java-friendly version of [[run()]] - * @since 1.3.0 */ + @Since("1.3.0") def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd) private def updateWeightsAndGaussians( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 2fa0473737aae..1a10a8b624218 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian @@ -43,8 +43,8 @@ import org.apache.spark.sql.{SQLContext, Row} * the weight for Gaussian i, and weights.sum == 1 * @param gaussians Array of MultivariateGaussian where gaussians(i) represents * the Multivariate Gaussian (Normal) Distribution for Gaussian i - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class GaussianMixtureModel( val weights: Array[Double], @@ -54,23 +54,21 @@ class GaussianMixtureModel( override protected def formatVersion = "1.0" - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians) } /** * Number of gaussians in mixture - * @since 1.3.0 */ + @Since("1.3.0") def k: Int = weights.length /** * Maps given points to their cluster indices. - * @since 1.3.0 */ + @Since("1.3.0") def predict(points: RDD[Vector]): RDD[Int] = { val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) @@ -78,8 +76,8 @@ class GaussianMixtureModel( /** * Maps given point to its cluster index. - * @since 1.5.0 */ + @Since("1.5.0") def predict(point: Vector): Int = { val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) r.indexOf(r.max) @@ -87,16 +85,16 @@ class GaussianMixtureModel( /** * Java-friendly version of [[predict()]] - * @since 1.4.0 */ + @Since("1.4.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] /** * Given the input vectors, return the membership value of each vector * to all mixture components. - * @since 1.3.0 */ + @Since("1.3.0") def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) @@ -108,8 +106,8 @@ class GaussianMixtureModel( /** * Given the input vector, return the membership values to all mixture components. 
- * @since 1.4.0 */ + @Since("1.4.0") def predictSoft(point: Vector): Array[Double] = { computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) } @@ -133,9 +131,7 @@ class GaussianMixtureModel( } } -/** - * @since 1.4.0 - */ +@Since("1.4.0") @Experimental object GaussianMixtureModel extends Loader[GaussianMixtureModel] { @@ -186,9 +182,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { } } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def load(sc: SparkContext, path: String) : GaussianMixtureModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 9ef6834e5ea8d..3e9545a74bef3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -49,20 +49,20 @@ class KMeans private ( /** * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}. - * @since 0.8.0 */ + @Since("0.8.0") def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4, Utils.random.nextLong()) /** * Number of clusters to create (k). - * @since 1.4.0 */ + @Since("1.4.0") def getK: Int = k /** * Set the number of clusters to create (k). Default: 2. - * @since 0.8.0 */ + @Since("0.8.0") def setK(k: Int): this.type = { this.k = k this @@ -70,14 +70,14 @@ class KMeans private ( /** * Maximum number of iterations to run. - * @since 1.4.0 */ + @Since("1.4.0") def getMaxIterations: Int = maxIterations /** * Set maximum number of iterations to run. Default: 20. - * @since 0.8.0 */ + @Since("0.8.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -85,16 +85,16 @@ class KMeans private ( /** * The initialization algorithm. This can be either "random" or "k-means||". - * @since 1.4.0 */ + @Since("1.4.0") def getInitializationMode: String = initializationMode /** * Set the initialization algorithm. This can be either "random" to choose random points as * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. - * @since 0.8.0 */ + @Since("0.8.0") def setInitializationMode(initializationMode: String): this.type = { KMeans.validateInitMode(initializationMode) this.initializationMode = initializationMode @@ -104,8 +104,8 @@ class KMeans private ( /** * :: Experimental :: * Number of runs of the algorithm to execute in parallel. - * @since 1.4.0 */ + @Since("1.4.0") @Experimental def getRuns: Int = runs @@ -114,8 +114,8 @@ class KMeans private ( * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. 
- * @since 0.8.0 */ + @Since("0.8.0") @Experimental def setRuns(runs: Int): this.type = { if (runs <= 0) { @@ -127,15 +127,15 @@ class KMeans private ( /** * Number of steps for the k-means|| initialization mode - * @since 1.4.0 */ + @Since("1.4.0") def getInitializationSteps: Int = initializationSteps /** * Set the number of steps for the k-means|| initialization mode. This is an advanced * setting -- the default of 5 is almost always enough. Default: 5. - * @since 0.8.0 */ + @Since("0.8.0") def setInitializationSteps(initializationSteps: Int): this.type = { if (initializationSteps <= 0) { throw new IllegalArgumentException("Number of initialization steps must be positive") @@ -146,15 +146,15 @@ class KMeans private ( /** * The distance threshold within which we've consider centers to have converged. - * @since 1.4.0 */ + @Since("1.4.0") def getEpsilon: Double = epsilon /** * Set the distance threshold within which we've consider centers to have converged. * If all centers move less than this Euclidean distance, we stop iterating one run. - * @since 0.8.0 */ + @Since("0.8.0") def setEpsilon(epsilon: Double): this.type = { this.epsilon = epsilon this @@ -162,14 +162,14 @@ class KMeans private ( /** * The random seed for cluster initialization. - * @since 1.4.0 */ + @Since("1.4.0") def getSeed: Long = seed /** * Set the random seed for cluster initialization. - * @since 1.4.0 */ + @Since("1.4.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -183,8 +183,8 @@ class KMeans private ( * Set the initial starting point, bypassing the random initialization or k-means|| * The condition model.k == this.k must be met, failure results * in an IllegalArgumentException. - * @since 1.4.0 */ + @Since("1.4.0") def setInitialModel(model: KMeansModel): this.type = { require(model.k == k, "mismatched cluster count") initialModel = Some(model) @@ -194,8 +194,8 @@ class KMeans private ( /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. - * @since 0.8.0 */ + @Since("0.8.0") def run(data: RDD[Vector]): KMeansModel = { if (data.getStorageLevel == StorageLevel.NONE) { @@ -453,14 +453,14 @@ class KMeans private ( /** * Top-level methods for calling K-means clustering. - * @since 0.8.0 */ +@Since("0.8.0") object KMeans { // Initialization mode names - /** @since 0.8.0 */ + @Since("0.8.0") val RANDOM = "random" - /** @since 0.8.0 */ + @Since("0.8.0") val K_MEANS_PARALLEL = "k-means||" /** @@ -472,8 +472,8 @@ object KMeans { * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). * @param seed random seed value for cluster initialization - * @since 1.3.0 */ + @Since("1.3.0") def train( data: RDD[Vector], k: Int, @@ -497,8 +497,8 @@ object KMeans { * @param maxIterations max number of iterations * @param runs number of parallel runs, defaults to 1. The best model is returned. * @param initializationMode initialization model, either "random" or "k-means||" (default). - * @since 0.8.0 */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -514,8 +514,8 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. - * @since 0.8.0 */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, @@ -525,8 +525,8 @@ object KMeans { /** * Trains a k-means model using specified parameters and the default values for unspecified. 
- * @since 0.8.0 */ + @Since("0.8.0") def train( data: RDD[Vector], k: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 8de2087ceb4df..e425ecdd481c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,6 +23,7 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable @@ -34,35 +35,35 @@ import org.apache.spark.sql.Row /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. - * @since 0.8.0 */ +@Since("0.8.0") class KMeansModel ( val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { /** * A Java-friendly constructor that takes an Iterable of Vectors. - * @since 1.4.0 */ + @Since("1.4.0") def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray) /** * Total number of clusters. - * @since 0.8.0 */ + @Since("0.8.0") def k: Int = clusterCenters.length /** * Returns the cluster index that a given point belongs to. - * @since 0.8.0 */ + @Since("0.8.0") def predict(point: Vector): Int = { KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1 } /** * Maps given points to their cluster indices. - * @since 1.0.0 */ + @Since("1.0.0") def predict(points: RDD[Vector]): RDD[Int] = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = points.context.broadcast(centersWithNorm) @@ -71,16 +72,16 @@ class KMeansModel ( /** * Maps given points to their cluster indices. - * @since 1.0.0 */ + @Since("1.0.0") def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this * model on the given data. 
- * @since 0.8.0 */ + @Since("0.8.0") def computeCost(data: RDD[Vector]): Double = { val centersWithNorm = clusterCentersWithNorm val bcCentersWithNorm = data.context.broadcast(centersWithNorm) @@ -90,9 +91,7 @@ class KMeansModel ( private def clusterCentersWithNorm: Iterable[VectorWithNorm] = clusterCenters.map(new VectorWithNorm(_)) - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { KMeansModel.SaveLoadV1_0.save(sc, this, path) } @@ -100,14 +99,10 @@ class KMeansModel ( override protected def formatVersion: String = "1.0" } -/** - * @since 1.4.0 - */ +@Since("1.4.0") object KMeansModel extends Loader[KMeansModel] { - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def load(sc: SparkContext, path: String): KMeansModel = { KMeansModel.SaveLoadV1_0.load(sc, path) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 2a8c6acbaec61..92a321afb0ca3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -43,8 +43,8 @@ import org.apache.spark.util.Utils * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class LDA private ( private var k: Int, @@ -57,8 +57,8 @@ class LDA private ( /** * Constructs a LDA instance with default parameters. - * @since 1.3.0 */ + @Since("1.3.0") def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) @@ -66,15 +66,15 @@ class LDA private ( /** * Number of topics to infer. I.e., the number of soft cluster centers. * - * @since 1.3.0 */ + @Since("1.3.0") def getK: Int = k /** * Number of topics to infer. I.e., the number of soft cluster centers. * (default = 10) - * @since 1.3.0 */ + @Since("1.3.0") def setK(k: Int): this.type = { require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") this.k = k @@ -86,8 +86,8 @@ class LDA private ( * distributions over topics ("theta"). * * This is the parameter to a Dirichlet distribution. - * @since 1.5.0 */ + @Since("1.5.0") def getAsymmetricDocConcentration: Vector = this.docConcentration /** @@ -96,8 +96,8 @@ class LDA private ( * * This method assumes the Dirichlet distribution is symmetric and can be described by a single * [[Double]] parameter. It should fail if docConcentration is asymmetric. - * @since 1.3.0 */ + @Since("1.3.0") def getDocConcentration: Double = { val parameter = docConcentration(0) if (docConcentration.size == 1) { @@ -131,8 +131,8 @@ class LDA private ( * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. 
- * @since 1.5.0 */ + @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { require(docConcentration.size > 0, "docConcentration must have > 0 elements") this.docConcentration = docConcentration @@ -141,8 +141,8 @@ class LDA private ( /** * Replicates a [[Double]] docConcentration to create a symmetric prior. - * @since 1.3.0 */ + @Since("1.3.0") def setDocConcentration(docConcentration: Double): this.type = { this.docConcentration = Vectors.dense(docConcentration) this @@ -150,26 +150,26 @@ class LDA private ( /** * Alias for [[getAsymmetricDocConcentration]] - * @since 1.5.0 */ + @Since("1.5.0") def getAsymmetricAlpha: Vector = getAsymmetricDocConcentration /** * Alias for [[getDocConcentration]] - * @since 1.3.0 */ + @Since("1.3.0") def getAlpha: Double = getDocConcentration /** * Alias for [[setDocConcentration()]] - * @since 1.5.0 */ + @Since("1.5.0") def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) /** * Alias for [[setDocConcentration()]] - * @since 1.3.0 */ + @Since("1.3.0") def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) /** @@ -180,8 +180,8 @@ class LDA private ( * * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. - * @since 1.3.0 */ + @Since("1.3.0") def getTopicConcentration: Double = this.topicConcentration /** @@ -205,8 +205,8 @@ class LDA private ( * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. - * @since 1.3.0 */ + @Since("1.3.0") def setTopicConcentration(topicConcentration: Double): this.type = { this.topicConcentration = topicConcentration this @@ -214,27 +214,27 @@ class LDA private ( /** * Alias for [[getTopicConcentration]] - * @since 1.3.0 */ + @Since("1.3.0") def getBeta: Double = getTopicConcentration /** * Alias for [[setTopicConcentration()]] - * @since 1.3.0 */ + @Since("1.3.0") def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. - * @since 1.3.0 */ + @Since("1.3.0") def getMaxIterations: Int = maxIterations /** * Maximum number of iterations for learning. * (default = 20) - * @since 1.3.0 */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -242,14 +242,14 @@ class LDA private ( /** * Random seed - * @since 1.3.0 */ + @Since("1.3.0") def getSeed: Long = seed /** * Random seed - * @since 1.3.0 */ + @Since("1.3.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -257,8 +257,8 @@ class LDA private ( /** * Period (in iterations) between checkpoints. - * @since 1.3.0 */ + @Since("1.3.0") def getCheckpointInterval: Int = checkpointInterval /** @@ -268,8 +268,8 @@ class LDA private ( * [[org.apache.spark.SparkContext]], this setting is ignored. 
* * @see [[org.apache.spark.SparkContext#setCheckpointDir]] - * @since 1.3.0 */ + @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -280,8 +280,8 @@ class LDA private ( * :: DeveloperApi :: * * LDAOptimizer used to perform the actual calculation - * @since 1.4.0 */ + @Since("1.4.0") @DeveloperApi def getOptimizer: LDAOptimizer = ldaOptimizer @@ -289,8 +289,8 @@ class LDA private ( * :: DeveloperApi :: * * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) - * @since 1.4.0 */ + @Since("1.4.0") @DeveloperApi def setOptimizer(optimizer: LDAOptimizer): this.type = { this.ldaOptimizer = optimizer @@ -300,8 +300,8 @@ class LDA private ( /** * Set the LDAOptimizer used to perform the actual calculation by algorithm name. * Currently "em", "online" are supported. - * @since 1.4.0 */ + @Since("1.4.0") def setOptimizer(optimizerName: String): this.type = { this.ldaOptimizer = optimizerName.toLowerCase match { @@ -321,8 +321,8 @@ class LDA private ( * (where the vocabulary size is the length of the vector). * Document IDs must be unique and >= 0. * @return Inferred LDA model - * @since 1.3.0 */ + @Since("1.3.0") def run(documents: RDD[(Long, Vector)]): LDAModel = { val state = ldaOptimizer.initialize(documents, this) var iter = 0 @@ -339,8 +339,8 @@ class LDA private ( /** * Java-friendly version of [[run()]] - * @since 1.3.0 */ + @Since("1.3.0") def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6bc68a4c18b99..667374a2bc418 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} @@ -192,24 +192,16 @@ class LocalLDAModel private[clustering] ( override protected[clustering] val gammaShape: Double = 100) extends LDAModel with Serializable { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def k: Int = topics.numCols - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def vocabSize: Int = topics.numRows - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def topicsMatrix: Matrix = topics - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val brzTopics = topics.toBreeze.toDenseMatrix Range(0, k).map { topicIndex => @@ -222,9 +214,7 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -238,16 +228,16 @@ class LocalLDAModel private[clustering] ( * * @param documents test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus - * @since 
1.5.0 */ + @Since("1.5.0") def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) /** * Java-friendly version of [[logLikelihood]] - * @since 1.5.0 */ + @Since("1.5.0") def logLikelihood(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { logLikelihood(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } @@ -258,8 +248,8 @@ class LocalLDAModel private[clustering] ( * * @param documents test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. - * @since 1.5.0 */ + @Since("1.5.0") def logPerplexity(documents: RDD[(Long, Vector)]): Double = { val corpusTokenCount = documents .map { case (_, termCounts) => termCounts.toArray.sum } @@ -267,9 +257,8 @@ class LocalLDAModel private[clustering] ( -logLikelihood(documents) / corpusTokenCount } - /** Java-friendly version of [[logPerplexity]] - * @since 1.5.0 - */ + /** Java-friendly version of [[logPerplexity]] */ + @Since("1.5.0") def logPerplexity(documents: JavaPairRDD[java.lang.Long, Vector]): Double = { logPerplexity(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } @@ -347,8 +336,8 @@ class LocalLDAModel private[clustering] ( * for each document. * @param documents documents to predict topic mixture distributions for * @return An RDD of (document ID, topic mixture distribution for document) - * @since 1.3.0 */ + @Since("1.3.0") // TODO: declare in LDAModel and override once implemented in DistributedLDAModel def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { // Double transpose because dirichletExpectation normalizes by row and we need to normalize @@ -376,8 +365,8 @@ class LocalLDAModel private[clustering] ( /** * Java-friendly version of [[topicDistributions]] - * @since 1.4.1 */ + @Since("1.4.1") def topicDistributions( documents: JavaPairRDD[java.lang.Long, Vector]): JavaPairRDD[java.lang.Long, Vector] = { val distributions = topicDistributions(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) @@ -451,9 +440,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def load(sc: SparkContext, path: String): LocalLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats @@ -510,8 +497,8 @@ class DistributedLDAModel private[clustering] ( * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. - * @since 1.3.0 */ + @Since("1.3.0") def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) @@ -521,8 +508,8 @@ class DistributedLDAModel private[clustering] ( * No guarantees are given about the ordering of the topics. * * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. 
- * @since 1.3.0 */ + @Since("1.3.0") override lazy val topicsMatrix: Matrix = { // Collect row-major topics val termTopicCounts: Array[(Int, TopicCounts)] = @@ -541,9 +528,7 @@ class DistributedLDAModel private[clustering] ( Matrices.fromBreeze(brzTopics) } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { val numTopics = k // Note: N_k is not needed to find the top terms, but it is needed to normalize weights @@ -582,8 +567,8 @@ class DistributedLDAModel private[clustering] ( * @return Array over topics. Each element represent as a pair of matching arrays: * (IDs for the documents, weights of the topic in these documents). * For each topic, documents are sorted in order of decreasing topic weights. - * @since 1.5.0 */ + @Since("1.5.0") def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { val numTopics = k val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = @@ -666,8 +651,8 @@ class DistributedLDAModel private[clustering] ( * - This excludes the prior; for that, use [[logPrior]]. * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the * hyperparameters. - * @since 1.3.0 */ + @Since("1.3.0") lazy val logLikelihood: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -693,8 +678,8 @@ class DistributedLDAModel private[clustering] ( /** * Log probability of the current parameter estimate: * log P(topics, topic distributions for docs | alpha, eta) - * @since 1.3.0 */ + @Since("1.3.0") lazy val logPrior: Double = { // TODO: generalize this for asymmetric (non-scalar) alpha val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object @@ -725,8 +710,8 @@ class DistributedLDAModel private[clustering] ( * ("theta_doc"). * * @return RDD of (document ID, topic distribution) pairs - * @since 1.3.0 */ + @Since("1.3.0") def topicDistributions: RDD[(Long, Vector)] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) @@ -735,8 +720,8 @@ class DistributedLDAModel private[clustering] ( /** * Java-friendly version of [[topicDistributions]] - * @since 1.4.1 */ + @Since("1.4.1") def javaTopicDistributions: JavaPairRDD[java.lang.Long, Vector] = { JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } @@ -744,8 +729,8 @@ class DistributedLDAModel private[clustering] ( /** * For each document, return the top k weighted topics for that document and their weights. 
* @return RDD of (doc ID, topic indices, topic weights) - * @since 1.5.0 */ + @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => val topIndices = argtopk(topicCounts, k) @@ -761,8 +746,8 @@ class DistributedLDAModel private[clustering] ( /** * Java-friendly version of [[topTopicsPerDocument]] - * @since 1.5.0 */ + @Since("1.5.0") def javaTopTopicsPerDocument(k: Int): JavaRDD[(java.lang.Long, Array[Int], Array[Double])] = { val topics = topTopicsPerDocument(k) topics.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Double])]].toJavaRDD() @@ -775,8 +760,8 @@ class DistributedLDAModel private[clustering] ( /** * Java-friendly version of [[topicDistributions]] - * @since 1.5.0 */ + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, @@ -877,9 +862,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def load(sc: SparkContext, path: String): DistributedLDAModel = { val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) implicit val formats = DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index cb517f9689ade..5c2aae6403bea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -23,7 +23,7 @@ import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, su import breeze.numerics.{trigamma, abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer @@ -35,8 +35,8 @@ import org.apache.spark.rdd.RDD * * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can * hold optimizer-specific parameters for users to set. - * @since 1.4.0 */ +@Since("1.4.0") @DeveloperApi sealed trait LDAOptimizer { @@ -74,8 +74,8 @@ sealed trait LDAOptimizer { * - Paper which clearly explains several algorithms, including EM: * Asuncion, Welling, Smyth, and Teh. * "On Smoothing and Inference for Topic Models." UAI, 2009. - * @since 1.4.0 */ +@Since("1.4.0") @DeveloperApi final class EMLDAOptimizer extends LDAOptimizer { @@ -226,8 +226,8 @@ final class EMLDAOptimizer extends LDAOptimizer { * * Original Online LDA paper: * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010. - * @since 1.4.0 */ +@Since("1.4.0") @DeveloperApi final class OnlineLDAOptimizer extends LDAOptimizer { @@ -276,16 +276,16 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. - * @since 1.4.0 */ + @Since("1.4.0") def getTau0: Double = this.tau0 /** * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. * Default: 1024, following the original Online LDA paper. 
- * @since 1.4.0 */ + @Since("1.4.0") def setTau0(tau0: Double): this.type = { require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 @@ -294,16 +294,16 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Learning rate: exponential decay rate - * @since 1.4.0 */ + @Since("1.4.0") def getKappa: Double = this.kappa /** * Learning rate: exponential decay rate---should be between * (0.5, 1.0] to guarantee asymptotic convergence. * Default: 0.51, based on the original Online LDA paper. - * @since 1.4.0 */ + @Since("1.4.0") def setKappa(kappa: Double): this.type = { require(kappa >= 0, s"Online LDA kappa must be nonnegative, but was set to $kappa") this.kappa = kappa @@ -312,8 +312,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Mini-batch fraction, which sets the fraction of document sampled and used in each iteration - * @since 1.4.0 */ + @Since("1.4.0") def getMiniBatchFraction: Double = this.miniBatchFraction /** @@ -325,8 +325,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * maxIterations * miniBatchFraction >= 1. * * Default: 0.05, i.e., 5% of total documents. - * @since 1.4.0 */ + @Since("1.4.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { require(miniBatchFraction > 0.0 && miniBatchFraction <= 1.0, s"Online LDA miniBatchFraction must be in range (0,1], but was set to $miniBatchFraction") @@ -337,16 +337,16 @@ final class OnlineLDAOptimizer extends LDAOptimizer { /** * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) * will be optimized during training. - * @since 1.5.0 */ + @Since("1.5.0") def getOptimzeAlpha: Boolean = this.optimizeAlpha /** * Sets whether to optimize alpha parameter during training. * * Default: false - * @since 1.5.0 */ + @Since("1.5.0") def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { this.optimizeAlpha = optimizeAlpha this diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index b4733ca975152..396b36f2f6454 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -21,7 +21,7 @@ import org.json4s.JsonDSL._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl @@ -39,16 +39,14 @@ import org.apache.spark.{Logging, SparkContext, SparkException} * * @param k number of clusters * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class PowerIterationClusteringModel( val k: Int, val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path) } @@ -56,9 +54,7 @@ class PowerIterationClusteringModel( override protected def formatVersion: String = "1.0" } -/** - * @since 1.4.0 - */ +@Since("1.4.0") object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { override def load(sc: SparkContext, path: String): 
PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) @@ -73,8 +69,8 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" /** - * @since 1.4.0 */ + @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ @@ -87,9 +83,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode dataRDD.write.parquet(Loader.dataPath(path)) } - /** - * @since 1.4.0 - */ + @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats val sqlContext = new SQLContext(sc) @@ -136,14 +130,14 @@ class PowerIterationClustering private[clustering] ( /** * Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, * initMode: "random"}. - * @since 1.3.0 */ + @Since("1.3.0") def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. - * @since 1.3.0 */ + @Since("1.3.0") def setK(k: Int): this.type = { this.k = k this @@ -151,8 +145,8 @@ class PowerIterationClustering private[clustering] ( /** * Set maximum number of iterations of the power iteration loop - * @since 1.3.0 */ + @Since("1.3.0") def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this @@ -161,8 +155,8 @@ class PowerIterationClustering private[clustering] ( /** * Set the initialization mode. This can be either "random" to use a random vector * as vertex properties, or "degree" to use normalized sum similarities. Default: random. - * @since 1.3.0 */ + @Since("1.3.0") def setInitializationMode(mode: String): this.type = { this.initMode = mode match { case "random" | "degree" => mode @@ -182,8 +176,8 @@ class PowerIterationClustering private[clustering] ( * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result - * @since 1.5.0 */ + @Since("1.5.0") def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { val w = normalize(graph) val w0 = initMode match { @@ -204,8 +198,8 @@ class PowerIterationClustering private[clustering] ( * assume s,,ij,, = 0.0. * * @return a [[PowerIterationClusteringModel]] that contains the clustering result - * @since 1.3.0 */ + @Since("1.3.0") def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) val w0 = initMode match { @@ -217,8 +211,8 @@ class PowerIterationClustering private[clustering] ( /** * A Java-friendly version of [[PowerIterationClustering.run]]. - * @since 1.3.0 */ + @Since("1.3.0") def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)]) : PowerIterationClusteringModel = { run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]]) @@ -242,9 +236,7 @@ class PowerIterationClustering private[clustering] ( } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") @Experimental object PowerIterationClustering extends Logging { @@ -253,8 +245,8 @@ object PowerIterationClustering extends Logging { * Cluster assignment. 
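// --- Illustrative usage sketch (not part of the patch) ---
// Demonstrates the PowerIterationClustering API annotated above: build an
// RDD of (srcId, dstId, similarity) tuples and cluster the implied graph.
// Assumes an existing SparkContext `sc`; the toy similarities are placeholders.
import org.apache.spark.mllib.clustering.PowerIterationClustering

val similarities = sc.parallelize(Seq(
  (0L, 1L, 0.9), (1L, 2L, 0.9), (2L, 3L, 0.1), (3L, 4L, 0.9)))

val picModel = new PowerIterationClustering()
  .setK(2)
  .setMaxIterations(20)
  .setInitializationMode("degree")
  .run(similarities)

picModel.assignments.collect().foreach { a =>
  println(s"node ${a.id} -> cluster ${a.cluster}")
}
// --- end sketch ---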
* @param id node id * @param cluster assigned cluster id - * @since 1.3.0 */ + @Since("1.3.0") @Experimental case class Assignment(id: Long, cluster: Int) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index a915804b02c90..41f2668ec6a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD @@ -63,9 +63,8 @@ import org.apache.spark.util.random.XORShiftRandom * such that at time t + h the discount applied to the data from t is 0.5. * The definition remains the same whether the time unit is given * as batches or points. - * @since 1.2.0 - * */ +@Since("1.2.0") @Experimental class StreamingKMeansModel( override val clusterCenters: Array[Vector], @@ -73,8 +72,8 @@ class StreamingKMeansModel( /** * Perform a k-means update on a batch of data. - * @since 1.2.0 */ + @Since("1.2.0") def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { // find nearest cluster to each point @@ -166,23 +165,23 @@ class StreamingKMeansModel( * .setRandomCenters(5, 100.0) * .trainOn(DStream) * }}} - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class StreamingKMeans( var k: Int, var decayFactor: Double, var timeUnit: String) extends Logging with Serializable { - /** @since 1.2.0 */ + @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) /** * Set the number of clusters. - * @since 1.2.0 */ + @Since("1.2.0") def setK(k: Int): this.type = { this.k = k this @@ -190,8 +189,8 @@ class StreamingKMeans( /** * Set the decay factor directly (for forgetful algorithms). - * @since 1.2.0 */ + @Since("1.2.0") def setDecayFactor(a: Double): this.type = { this.decayFactor = a this @@ -199,8 +198,8 @@ class StreamingKMeans( /** * Set the half life and time unit ("batches" or "points") for forgetful algorithms. - * @since 1.2.0 */ + @Since("1.2.0") def setHalfLife(halfLife: Double, timeUnit: String): this.type = { if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) @@ -213,8 +212,8 @@ class StreamingKMeans( /** * Specify initial centers directly. - * @since 1.2.0 */ + @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { model = new StreamingKMeansModel(centers, weights) this @@ -226,8 +225,8 @@ class StreamingKMeans( * @param dim Number of dimensions * @param weight Weight for each center * @param seed Random seed - * @since 1.2.0 */ + @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) @@ -238,8 +237,8 @@ class StreamingKMeans( /** * Return the latest model. 
- * @since 1.2.0 */ + @Since("1.2.0") def latestModel(): StreamingKMeansModel = { model } @@ -251,8 +250,8 @@ class StreamingKMeans( * and updates the model using each batch of data from the stream. * * @param data DStream containing vector data - * @since 1.2.0 */ + @Since("1.2.0") def trainOn(data: DStream[Vector]) { assertInitialized() data.foreachRDD { (rdd, time) => @@ -262,8 +261,8 @@ class StreamingKMeans( /** * Java-friendly version of `trainOn`. - * @since 1.4.0 */ + @Since("1.4.0") def trainOn(data: JavaDStream[Vector]): Unit = trainOn(data.dstream) /** @@ -271,8 +270,8 @@ class StreamingKMeans( * * @param data DStream containing vector data * @return DStream containing predictions - * @since 1.2.0 */ + @Since("1.2.0") def predictOn(data: DStream[Vector]): DStream[Int] = { assertInitialized() data.map(model.predict) @@ -280,8 +279,8 @@ class StreamingKMeans( /** * Java-friendly version of `predictOn`. - * @since 1.4.0 */ + @Since("1.4.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Integer] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Integer]]) } @@ -292,8 +291,8 @@ class StreamingKMeans( * @param data DStream containing (key, feature vector) pairs * @tparam K key type * @return DStream containing the input keys and the predictions as values - * @since 1.2.0 */ + @Since("1.2.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { assertInitialized() data.mapValues(model.predict) @@ -301,8 +300,8 @@ class StreamingKMeans( /** * Java-friendly version of `predictOnValues`. - * @since 1.4.0 */ + @Since("1.4.0") def predictOnValues[K]( data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Integer] = { implicit val tag = fakeClassTag[K] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 486741edd6f5a..76ae847921f44 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.binary._ @@ -41,8 +41,8 @@ import org.apache.spark.sql.DataFrame * of bins may not exactly equal numBins. The last bin in each partition may * be smaller as a result, meaning there may be an extra sample at * partition boundaries. - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class BinaryClassificationMetrics( val scoreAndLabels: RDD[(Double, Double)], @@ -52,8 +52,8 @@ class BinaryClassificationMetrics( /** * Defaults `numBins` to 0. - * @since 1.0.0 */ + @Since("1.0.0") def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0) /** @@ -65,16 +65,16 @@ class BinaryClassificationMetrics( /** * Unpersist intermediate RDDs used in the computation. - * @since 1.0.0 */ + @Since("1.0.0") def unpersist() { cumulativeCounts.unpersist() } /** * Returns thresholds in descending order. 
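// --- Illustrative usage sketch (not part of the patch) ---
// Ties together the StreamingKMeans methods annotated above: configure the
// model, train on one DStream and predict on another. Assumes an existing
// StreamingContext and two placeholder DStreams of Vectors,
// `trainingStream` and `testStream`.
import org.apache.spark.mllib.clustering.StreamingKMeans

val skm = new StreamingKMeans()
  .setK(3)
  .setDecayFactor(1.0)
  .setRandomCenters(5, 100.0)   // 5 dimensions, initial weight 100.0

skm.trainOn(trainingStream)
skm.predictOn(testStream).print()
// The current centers can be inspected between batches:
// skm.latestModel().clusterCenters.foreach(println)
// --- end sketch ---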
- * @since 1.0.0 */ + @Since("1.0.0") def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) /** @@ -82,8 +82,8 @@ class BinaryClassificationMetrics( * which is an RDD of (false positive rate, true positive rate) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic - * @since 1.0.0 */ + @Since("1.0.0") def roc(): RDD[(Double, Double)] = { val rocCurve = createCurve(FalsePositiveRate, Recall) val sc = confusions.context @@ -94,16 +94,16 @@ class BinaryClassificationMetrics( /** * Computes the area under the receiver operating characteristic (ROC) curve. - * @since 1.0.0 */ + @Since("1.0.0") def areaUnderROC(): Double = AreaUnderCurve.of(roc()) /** * Returns the precision-recall curve, which is an RDD of (recall, precision), * NOT (precision, recall), with (0.0, 1.0) prepended to it. * @see http://en.wikipedia.org/wiki/Precision_and_recall - * @since 1.0.0 */ + @Since("1.0.0") def pr(): RDD[(Double, Double)] = { val prCurve = createCurve(Recall, Precision) val sc = confusions.context @@ -113,8 +113,8 @@ class BinaryClassificationMetrics( /** * Computes the area under the precision-recall curve. - * @since 1.0.0 */ + @Since("1.0.0") def areaUnderPR(): Double = AreaUnderCurve.of(pr()) /** @@ -122,26 +122,26 @@ class BinaryClassificationMetrics( * @param beta the beta factor in F-Measure computation. * @return an RDD of (threshold, F-Measure) pairs. * @see http://en.wikipedia.org/wiki/F1_score - * @since 1.0.0 */ + @Since("1.0.0") def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) /** * Returns the (threshold, F-Measure) curve with beta = 1.0. - * @since 1.0.0 */ + @Since("1.0.0") def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) /** * Returns the (threshold, precision) curve. - * @since 1.0.0 */ + @Since("1.0.0") def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) /** * Returns the (threshold, recall) curve. - * @since 1.0.0 */ + @Since("1.0.0") def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) private lazy val ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index dddfa3ea5b800..02e89d921033c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.evaluation import scala.collection.Map import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{Matrices, Matrix} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -30,8 +30,8 @@ import org.apache.spark.sql.DataFrame * Evaluator for multiclass classification. * * @param predictionAndLabels an RDD of (prediction, label) pairs. 
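// --- Illustrative usage sketch (not part of the patch) ---
// Exercises the curve and area methods annotated above. Assumes an existing
// SparkContext `sc`; the (score, label) pairs are placeholders.
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val scoreAndLabels = sc.parallelize(Seq(
  (0.9, 1.0), (0.8, 1.0), (0.4, 0.0), (0.2, 1.0), (0.1, 0.0)))

val metrics = new BinaryClassificationMetrics(scoreAndLabels)
println(s"AUROC = ${metrics.areaUnderROC()}")
println(s"AUPRC = ${metrics.areaUnderPR()}")
metrics.fMeasureByThreshold().collect().foreach { case (t, f) =>
  println(s"threshold=$t F1=$f")
}
metrics.unpersist()
// --- end sketch ---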
- * @since 1.1.0 */ +@Since("1.1.0") @Experimental class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { @@ -65,8 +65,8 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * predicted classes are in columns, * they are ordered by class label ascending, * as in "labels" - * @since 1.1.0 */ + @Since("1.1.0") def confusionMatrix: Matrix = { val n = labels.size val values = Array.ofDim[Double](n * n) @@ -85,15 +85,15 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns true positive rate for a given label (category) * @param label the label. - * @since 1.1.0 */ + @Since("1.1.0") def truePositiveRate(label: Double): Double = recall(label) /** * Returns false positive rate for a given label (category) * @param label the label. - * @since 1.1.0 */ + @Since("1.1.0") def falsePositiveRate(label: Double): Double = { val fp = fpByClass.getOrElse(label, 0) fp.toDouble / (labelCount - labelCountByClass(label)) @@ -102,8 +102,8 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision for a given label (category) * @param label the label. - * @since 1.1.0 */ + @Since("1.1.0") def precision(label: Double): Double = { val tp = tpByClass(label) val fp = fpByClass.getOrElse(label, 0) @@ -113,16 +113,16 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns recall for a given label (category) * @param label the label. - * @since 1.1.0 */ + @Since("1.1.0") def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) /** * Returns f-measure for a given label (category) * @param label the label. * @param beta the beta parameter. - * @since 1.1.0 */ + @Since("1.1.0") def fMeasure(label: Double, beta: Double): Double = { val p = precision(label) val r = recall(label) @@ -133,8 +133,8 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns f1-measure for a given label (category) * @param label the label. - * @since 1.1.0 */ + @Since("1.1.0") def fMeasure(label: Double): Double = fMeasure(label, 1.0) /** @@ -187,8 +187,8 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f-measure * @param beta the beta parameter. - * @since 1.1.0 */ + @Since("1.1.0") def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => fMeasure(category, beta) * count.toDouble / labelCount }.sum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index 77cb1e09bdbb5..a0a8d9c56847b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.evaluation +import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.sql.DataFrame @@ -25,8 +26,8 @@ import org.apache.spark.sql.DataFrame * Evaluator for multilabel classification. * @param predictionAndLabels an RDD of (predictions, labels) pairs, * both are non-null Arrays, each with unique elements. 
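// --- Illustrative usage sketch (not part of the patch) ---
// Shows the per-label and weighted metrics annotated above. Assumes an
// existing SparkContext `sc`; the (prediction, label) pairs are placeholders.
import org.apache.spark.mllib.evaluation.MulticlassMetrics

val predictionAndLabels = sc.parallelize(Seq(
  (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0), (2.0, 1.0)))

val mm = new MulticlassMetrics(predictionAndLabels)
println(mm.confusionMatrix)
mm.labels.foreach { l =>
  println(s"label $l: precision=${mm.precision(l)} " +
    s"recall=${mm.recall(l)} f1=${mm.fMeasure(l)}")
}
println(s"weighted F(beta=0.5) = ${mm.weightedFMeasure(0.5)}")
// --- end sketch ---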
- * @since 1.2.0 */ +@Since("1.2.0") class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { /** @@ -104,8 +105,8 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns precision for a given label (category) * @param label the label. - * @since 1.2.0 */ + @Since("1.2.0") def precision(label: Double): Double = { val tp = tpPerClass(label) val fp = fpPerClass.getOrElse(label, 0L) @@ -115,8 +116,8 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns recall for a given label (category) * @param label the label. - * @since 1.2.0 */ + @Since("1.2.0") def recall(label: Double): Double = { val tp = tpPerClass(label) val fn = fnPerClass.getOrElse(label, 0L) @@ -126,8 +127,8 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns f1-measure for a given label (category) * @param label the label. - * @since 1.2.0 */ + @Since("1.2.0") def f1Measure(label: Double): Double = { val p = precision(label) val r = recall(label) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 063fbed8cdeea..a7f43f0b110f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD @@ -34,8 +34,8 @@ import org.apache.spark.rdd.RDD * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) extends Logging with Serializable { @@ -56,8 +56,8 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated precision, must be positive * @return the average precision at the first k ranking positions - * @since 1.2.0 */ + @Since("1.2.0") def precisionAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -126,8 +126,8 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])] * * @param k the position to compute the truncated ndcg, must be positive * @return the average ndcg at the first k ranking positions - * @since 1.2.0 */ + @Since("1.2.0") def ndcgAt(k: Int): Double = { require(k > 0, "ranking position k should be positive") predictionAndLabels.map { case (pred, lab) => @@ -165,8 +165,8 @@ object RankingMetrics { /** * Creates a [[RankingMetrics]] instance (for Java users). 
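// --- Illustrative usage sketch (not part of the patch) ---
// Shows precisionAt and ndcgAt from the RankingMetrics hunk above. Assumes an
// existing SparkContext `sc`; the (predicted ranking, ground-truth set) pairs
// are placeholders.
import org.apache.spark.mllib.evaluation.RankingMetrics

val rankingData = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2), Array(1, 2, 3))))

val rm = new RankingMetrics(rankingData)
println(s"precision@3 = ${rm.precisionAt(3)}")
println(s"ndcg@3      = ${rm.ndcgAt(3)}")
// --- end sketch ---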
* @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs - * @since 1.4.0 */ + @Since("1.4.0") def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = { implicit val tag = JavaSparkContext.fakeClassTag[E] val rdd = predictionAndLabels.rdd.map { case (predictions, labels) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 54dfd8c099494..36a6c357c3897 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors @@ -29,8 +29,8 @@ import org.apache.spark.sql.DataFrame * Evaluator for regression. * * @param predictionAndObservations an RDD of (prediction, observation) pairs. - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { @@ -67,8 +67,8 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend * Returns the variance explained by regression. * explainedVariance = \sum_i (\hat{y_i} - \bar{y})^2 / n * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] - * @since 1.2.0 */ + @Since("1.2.0") def explainedVariance: Double = { SSreg / summary.count } @@ -76,8 +76,8 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. - * @since 1.2.0 */ + @Since("1.2.0") def meanAbsoluteError: Double = { summary.normL1(1) / summary.count } @@ -85,8 +85,8 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. - * @since 1.2.0 */ + @Since("1.2.0") def meanSquaredError: Double = { SSerr / summary.count } @@ -94,8 +94,8 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. - * @since 1.2.0 */ + @Since("1.2.0") def rootMeanSquaredError: Double = { math.sqrt(this.meanSquaredError) } @@ -103,8 +103,8 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend /** * Returns R^2^, the unadjusted coefficient of determination. 
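// --- Illustrative usage sketch (not part of the patch) ---
// Covers the error and variance measures annotated above. Assumes an existing
// SparkContext `sc`; the (prediction, observation) pairs are placeholders.
import org.apache.spark.mllib.evaluation.RegressionMetrics

val predictionAndObservations = sc.parallelize(Seq(
  (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))

val reg = new RegressionMetrics(predictionAndObservations)
println(s"MAE  = ${reg.meanAbsoluteError}")
println(s"MSE  = ${reg.meanSquaredError}")
println(s"RMSE = ${reg.rootMeanSquaredError}")
println(s"R^2  = ${reg.r2}")
// --- end sketch ---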
* @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] - * @since 1.2.0 */ + @Since("1.2.0") def r2: Double = { 1 - SSerr / SStot } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 7f4de77044994..ba3b447a83398 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -20,7 +20,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.AssociationRules.Rule @@ -33,24 +33,22 @@ import org.apache.spark.rdd.RDD * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates * association rules which have a single item as the consequent. * - * @since 1.5.0 */ +@Since("1.5.0") @Experimental class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minConfidence = 0.8}. - * - * @since 1.5.0 */ + @Since("1.5.0") def this() = this(0.8) /** * Sets the minimal confidence (default: `0.8`). - * - * @since 1.5.0 */ + @Since("1.5.0") def setMinConfidence(minConfidence: Double): this.type = { require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence @@ -62,8 +60,8 @@ class AssociationRules private[fpm] ( * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] * @return a [[Set[Rule[Item]]] containing the assocation rules. * - * @since 1.5.0 */ + @Since("1.5.0") def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) val candidates = freqItemsets.flatMap { itemset => @@ -102,8 +100,8 @@ object AssociationRules { * instead. * @tparam Item item type * - * @since 1.5.0 */ + @Since("1.5.0") @Experimental class Rule[Item] private[fpm] ( val antecedent: Array[Item], @@ -114,8 +112,8 @@ object AssociationRules { /** * Returns the confidence of the rule. * - * @since 1.5.0 */ + @Since("1.5.0") def confidence: Double = freqUnion.toDouble / freqAntecedent require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { @@ -127,8 +125,8 @@ object AssociationRules { /** * Returns antecedent in a Java List. * - * @since 1.5.0 */ + @Since("1.5.0") def javaAntecedent: java.util.List[Item] = { antecedent.toList.asJava } @@ -136,8 +134,8 @@ object AssociationRules { /** * Returns consequent in a Java List. 
* - * @since 1.5.0 */ + @Since("1.5.0") def javaConsequent: java.util.List[Item] = { consequent.toList.asJava } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index e2370a52f4930..e37f806271680 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ @@ -39,15 +39,15 @@ import org.apache.spark.storage.StorageLevel * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. * @param confidence minimal confidence of the rules produced - * @since 1.5.0 */ + @Since("1.5.0") def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) @@ -71,8 +71,8 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex * @see [[http://en.wikipedia.org/wiki/Association_rule_learning Association rule learning * (Wikipedia)]] * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class FPGrowth private ( private var minSupport: Double, @@ -82,15 +82,15 @@ class FPGrowth private ( * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same * as the input data}. * - * @since 1.3.0 */ + @Since("1.3.0") def this() = this(0.3, -1) /** * Sets the minimal support level (default: `0.3`). * - * @since 1.3.0 */ + @Since("1.3.0") def setMinSupport(minSupport: Double): this.type = { this.minSupport = minSupport this @@ -99,8 +99,8 @@ class FPGrowth private ( /** * Sets the number of partitions used by parallel FP-growth (default: same as input data). * - * @since 1.3.0 */ + @Since("1.3.0") def setNumPartitions(numPartitions: Int): this.type = { this.numPartitions = numPartitions this @@ -111,8 +111,8 @@ class FPGrowth private ( * @param data input data set, each element contains a transaction * @return an [[FPGrowthModel]] * - * @since 1.3.0 */ + @Since("1.3.0") def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") @@ -213,8 +213,8 @@ class FPGrowth private ( /** * :: Experimental :: * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental object FPGrowth { @@ -224,15 +224,15 @@ object FPGrowth { * @param freq frequency * @tparam Item item type * - * @since 1.3.0 */ + @Since("1.3.0") class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { /** * Returns items in a Java List. 
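// --- Illustrative usage sketch (not part of the patch) ---
// Runs FP-growth on toy transactions and derives association rules via
// generateAssociationRules, tying together the AssociationRules and FPGrowth
// hunks above. Assumes an existing SparkContext `sc`; the transactions are
// placeholders.
import org.apache.spark.mllib.fpm.FPGrowth

val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("a", "c"),
  Array("b", "c")))

val fpModel = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .run(transactions)

fpModel.freqItemsets.collect().foreach { itemset =>
  println(s"${itemset.items.mkString("[", ",", "]")} freq=${itemset.freq}")
}

fpModel.generateAssociationRules(confidence = 0.8).collect().foreach { rule =>
  println(s"${rule.antecedent.mkString(",")} => ${rule.consequent.mkString(",")} " +
    s"(conf=${rule.confidence})")
}
// --- end sketch ---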
* - * @since 1.3.0 */ + @Since("1.3.0") def javaItems: java.util.List[Item] = { items.toList.asJava } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index dfa8910fcbc8c..28b5b4637bf17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ @@ -227,8 +227,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { * @param values matrix entries in column major if not transposed or in row major otherwise * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. - * @since 1.0.0 */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) class DenseMatrix( val numRows: Int, @@ -253,8 +253,8 @@ class DenseMatrix( * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major - * @since 1.3.0 */ + @Since("1.3.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) @@ -278,9 +278,7 @@ class DenseMatrix( private[mllib] def apply(i: Int): Double = values(i) - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def apply(i: Int, j: Int): Double = values(index(i, j)) private[mllib] def index(i: Int, j: Int): Int = { @@ -291,9 +289,7 @@ class DenseMatrix( values(index(i, j)) = v } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) private[spark] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), @@ -309,9 +305,7 @@ class DenseMatrix( this } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def transpose: DenseMatrix = new DenseMatrix(numCols, numRows, values, !isTransposed) private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { @@ -342,21 +336,17 @@ class DenseMatrix( } } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def numActives: Int = values.length /** * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. - * @since 1.3.0 */ + @Since("1.3.0") def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -383,8 +373,8 @@ class DenseMatrix( /** * Factory methods for [[org.apache.spark.mllib.linalg.DenseMatrix]]. 
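// --- Illustrative usage sketch (not part of the patch) ---
// Shows the column-major layout and the toSparse/transpose members annotated
// above; purely local code, no SparkContext needed.
import org.apache.spark.mllib.linalg.DenseMatrix

// 2 x 3 matrix; values are given in column-major order:
//   [ 1.0  0.0  5.0 ]
//   [ 2.0  0.0  6.0 ]
val dm = new DenseMatrix(2, 3, Array(1.0, 2.0, 0.0, 0.0, 5.0, 6.0))
println(dm(1, 2))          // 6.0
println(dm.numNonzeros)    // 4
val sparse = dm.toSparse   // drops the all-zero column
val dt = dm.transpose      // 3 x 2 view sharing the same values array
// --- end sketch ---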
- * @since 1.3.0 */ +@Since("1.3.0") object DenseMatrix { /** @@ -392,8 +382,8 @@ object DenseMatrix { * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros - * @since 1.3.0 */ + @Since("1.3.0") def zeros(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -405,8 +395,8 @@ object DenseMatrix { * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones - * @since 1.3.0 */ + @Since("1.3.0") def ones(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -417,8 +407,8 @@ object DenseMatrix { * Generate an Identity Matrix in `DenseMatrix` format. * @param n number of rows and columns of the matrix * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal - * @since 1.3.0 */ + @Since("1.3.0") def eye(n: Int): DenseMatrix = { val identity = DenseMatrix.zeros(n, n) var i = 0 @@ -435,8 +425,8 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -449,8 +439,8 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -462,8 +452,8 @@ object DenseMatrix { * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` * on the diagonal - * @since 1.3.0 */ + @Since("1.3.0") def diag(vector: Vector): DenseMatrix = { val n = vector.size val matrix = DenseMatrix.zeros(n, n) @@ -498,8 +488,8 @@ object DenseMatrix { * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. - * @since 1.2.0 */ +@Since("1.2.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) class SparseMatrix( val numRows: Int, @@ -536,8 +526,8 @@ class SparseMatrix( * @param rowIndices the row index of the entry. 
They must be in strictly increasing * order for each column * @param values non-zero matrix entries in column major - * @since 1.3.0 */ + @Since("1.3.0") def this( numRows: Int, numCols: Int, @@ -560,8 +550,8 @@ class SparseMatrix( } /** - * @since 1.3.0 */ + @Since("1.3.0") override def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) @@ -585,9 +575,7 @@ class SparseMatrix( } } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def copy: SparseMatrix = { new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.clone()) } @@ -605,9 +593,7 @@ class SparseMatrix( this } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def transpose: SparseMatrix = new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) @@ -641,28 +627,24 @@ class SparseMatrix( /** * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed * set to false. - * @since 1.3.0 */ + @Since("1.3.0") def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def numNonzeros: Int = values.count(_ != 0) - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def numActives: Int = values.length } /** * Factory methods for [[org.apache.spark.mllib.linalg.SparseMatrix]]. - * @since 1.3.0 */ +@Since("1.3.0") object SparseMatrix { /** @@ -673,8 +655,8 @@ object SparseMatrix { * @param numCols number of columns of the matrix * @param entries Array of (i, j, value) tuples * @return The corresponding `SparseMatrix` - * @since 1.3.0 */ + @Since("1.3.0") def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) val numEntries = sortedEntries.size @@ -722,8 +704,8 @@ object SparseMatrix { * Generate an Identity Matrix in `SparseMatrix` format. * @param n number of rows and columns of the matrix * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal - * @since 1.3.0 */ + @Since("1.3.0") def speye(n: Int): SparseMatrix = { new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) } @@ -792,8 +774,8 @@ object SparseMatrix { * @param density the desired density for the matrix * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextDouble()) @@ -806,8 +788,8 @@ object SparseMatrix { * @param density the desired density for the matrix * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextGaussian()) @@ -818,8 +800,8 @@ object SparseMatrix { * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal - * @since 1.3.0 */ + @Since("1.3.0") def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { @@ -835,8 +817,8 @@ object SparseMatrix { /** * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]]. 
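// --- Illustrative usage sketch (not part of the patch) ---
// Builds a SparseMatrix from (row, col, value) tuples with fromCOO and uses
// the accessors annotated above; purely local code.
import org.apache.spark.mllib.linalg.SparseMatrix

val cooEntries = Seq((0, 0, 1.0), (2, 1, 3.0), (1, 2, 2.0))
val sm = SparseMatrix.fromCOO(3, 3, cooEntries)
println(sm(2, 1))          // 3.0
println(sm.numActives)     // 3 stored entries
val dense = sm.toDense     // expand to a DenseMatrix
val eye = SparseMatrix.speye(3)   // sparse 3 x 3 identity
// --- end sketch ---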
- * @since 1.0.0 */ +@Since("1.0.0") object Matrices { /** @@ -845,8 +827,8 @@ object Matrices { * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major - * @since 1.0.0 */ + @Since("1.0.0") def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { new DenseMatrix(numRows, numCols, values) } @@ -859,8 +841,8 @@ object Matrices { * @param colPtrs the index corresponding to the start of a new column * @param rowIndices the row index of the entry * @param values non-zero matrix entries in column major - * @since 1.2.0 */ + @Since("1.2.0") def sparse( numRows: Int, numCols: Int, @@ -893,8 +875,8 @@ object Matrices { * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros - * @since 1.2.0 */ + @Since("1.2.0") def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** @@ -902,24 +884,24 @@ object Matrices { * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of ones - * @since 1.2.0 */ + @Since("1.2.0") def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** * Generate a dense Identity Matrix in `Matrix` format. * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal - * @since 1.2.0 */ + @Since("1.2.0") def eye(n: Int): Matrix = DenseMatrix.eye(n) /** * Generate a sparse Identity Matrix in `Matrix` format. * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal - * @since 1.3.0 */ + @Since("1.3.0") def speye(n: Int): Matrix = SparseMatrix.speye(n) /** @@ -928,8 +910,8 @@ object Matrices { * @param numCols number of columns of the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) - * @since 1.2.0 */ + @Since("1.2.0") def rand(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.rand(numRows, numCols, rng) @@ -940,8 +922,8 @@ object Matrices { * @param density the desired density for the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprand(numRows, numCols, density, rng) @@ -951,8 +933,8 @@ object Matrices { * @param numCols number of columns of the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) - * @since 1.2.0 */ + @Since("1.2.0") def randn(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.randn(numRows, numCols, rng) @@ -963,8 +945,8 @@ object Matrices { * @param density the desired density for the matrix * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) - * @since 1.3.0 */ + @Since("1.3.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprandn(numRows, numCols, density, rng) @@ -973,8 +955,8 @@ object Matrices { * @param vector a `Vector` that will form the values on the diagonal of the matrix * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal - * @since 1.2.0 */ + @Since("1.2.0") def diag(vector: 
Vector): Matrix = DenseMatrix.diag(vector) /** @@ -983,8 +965,8 @@ object Matrices { * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were horizontally concatenated - * @since 1.3.0 */ + @Since("1.3.0") def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) @@ -1042,8 +1024,8 @@ object Matrices { * a sparse matrix. If the Array is empty, an empty `DenseMatrix` will be returned. * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were vertically concatenated - * @since 1.3.0 */ + @Since("1.3.0") def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 8f504f6984cb0..a37aca99d5e72 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Represents singular value decomposition (SVD) factors. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 52ef7be3b38be..3d577edbe23e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, Since} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow @@ -240,14 +240,14 @@ class VectorUDT extends UserDefinedType[Vector] { * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. * We don't use the name `Vector` because Scala imports * [[scala.collection.immutable.Vector]] by default. - * @since 1.0.0 */ +@Since("1.0.0") object Vectors { /** * Creates a dense vector from its values. - * @since 1.0.0 */ + @Since("1.0.0") @varargs def dense(firstValue: Double, otherValues: Double*): Vector = new DenseVector((firstValue +: otherValues).toArray) @@ -255,8 +255,8 @@ object Vectors { // A dummy implicit is used to avoid signature collision with the one generated by @varargs. /** * Creates a dense vector from a double array. - * @since 1.0.0 */ + @Since("1.0.0") def dense(values: Array[Double]): Vector = new DenseVector(values) /** @@ -265,8 +265,8 @@ object Vectors { * @param size vector size. * @param indices index array, must be strictly increasing. * @param values value array, must have the same length as indices. 
- * @since 1.0.0 */ + @Since("1.0.0") def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector = new SparseVector(size, indices, values) @@ -275,8 +275,8 @@ object Vectors { * * @param size vector size. * @param elements vector elements in (index, value) pairs. - * @since 1.0.0 */ + @Since("1.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { require(size > 0, "The size of the requested sparse vector must be greater than 0.") @@ -297,8 +297,8 @@ object Vectors { * * @param size vector size. * @param elements vector elements in (index, value) pairs. - * @since 1.0.0 */ + @Since("1.0.0") def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = { sparse(size, elements.asScala.map { case (i, x) => (i.intValue(), x.doubleValue()) @@ -310,16 +310,16 @@ object Vectors { * * @param size vector size * @return a zero vector - * @since 1.1.0 */ + @Since("1.1.0") def zeros(size: Int): Vector = { new DenseVector(new Array[Double](size)) } /** * Parses a string resulted from [[Vector.toString]] into a [[Vector]]. - * @since 1.1.0 */ + @Since("1.1.0") def parse(s: String): Vector = { parseNumeric(NumericParser.parse(s)) } @@ -362,8 +362,8 @@ object Vectors { * @param vector input vector. * @param p norm. * @return norm in L^p^ space. - * @since 1.3.0 */ + @Since("1.3.0") def norm(vector: Vector, p: Double): Double = { require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + s"You specified p=$p.") @@ -415,8 +415,8 @@ object Vectors { * @param v1 first Vector. * @param v2 second Vector. * @return squared distance between two Vectors. - * @since 1.3.0 */ + @Since("1.3.0") def sqdist(v1: Vector, v2: Vector): Double = { require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + s"=${v2.size}.") @@ -529,33 +529,25 @@ object Vectors { /** * A dense vector represented by a value array. - * @since 1.0.0 */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { - /** - * @since 1.0.0 - */ + @Since("1.0.0") override def size: Int = values.length override def toString: String = values.mkString("[", ",", "]") - /** - * @since 1.0.0 - */ + @Since("1.0.0") override def toArray: Array[Double] = values private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) - /** - * @since 1.0.0 - */ + @Since("1.0.0") override def apply(i: Int): Double = values(i) - /** - * @since 1.1.0 - */ + @Since("1.1.0") override def copy: DenseVector = { new DenseVector(values.clone()) } @@ -587,14 +579,10 @@ class DenseVector(val values: Array[Double]) extends Vector { result } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def numActives: Int = size - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def numNonzeros: Int = { // same as values.count(_ != 0.0) but faster var nnz = 0 @@ -606,9 +594,7 @@ class DenseVector(val values: Array[Double]) extends Vector { nnz } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros val ii = new Array[Int](nnz) @@ -624,9 +610,7 @@ class DenseVector(val values: Array[Double]) extends Vector { new SparseVector(size, ii, vv) } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -646,9 +630,7 @@ class DenseVector(val values: Array[Double]) extends Vector { } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object DenseVector { /** Extracts the value array from a dense vector. 
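// --- Illustrative usage sketch (not part of the patch) ---
// Shows the dense/sparse factories plus norm, sqdist, and parse from the
// Vectors hunk above; purely local code.
import org.apache.spark.mllib.linalg.Vectors

val dv = Vectors.dense(1.0, 0.0, 3.0)
val sv = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))  // same vector, sparse
println(Vectors.norm(dv, 2.0))       // L2 norm
println(Vectors.sqdist(dv, sv))      // 0.0, identical contents
println(Vectors.parse("[1.0,0.0,3.0]"))
// --- end sketch ---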
*/ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) @@ -660,8 +642,8 @@ object DenseVector { * @param size size of the vector. * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. - * @since 1.0.0 */ +@Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, @@ -677,9 +659,7 @@ class SparseVector( override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" - /** - * @since 1.0.0 - */ + @Since("1.0.0") override def toArray: Array[Double] = { val data = new Array[Double](size) var i = 0 @@ -691,9 +671,7 @@ class SparseVector( data } - /** - * @since 1.1.0 - */ + @Since("1.1.0") override def copy: SparseVector = { new SparseVector(size, indices.clone(), values.clone()) } @@ -734,14 +712,10 @@ class SparseVector( result } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def numActives: Int = values.length - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def numNonzeros: Int = { var nnz = 0 values.foreach { v => @@ -752,9 +726,7 @@ class SparseVector( nnz } - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def toSparse: SparseVector = { val nnz = numNonzeros if (nnz == numActives) { @@ -774,9 +746,7 @@ class SparseVector( } } - /** - * @since 1.5.0 - */ + @Since("1.5.0") override def argmax: Int = { if (size == 0) { -1 @@ -847,9 +817,7 @@ class SparseVector( } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object SparseVector { def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index cfb6680a18b34..94376c24a7ac6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{Logging, Partitioner, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -128,9 +128,8 @@ private[mllib] object GridPartitioner { * the number of rows will be calculated when `numRows` is invoked. * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to * zero, the number of columns will be calculated when `numCols` is invoked. - * @since 1.3.0 - * */ +@Since("1.3.0") @Experimental class BlockMatrix( val blocks: RDD[((Int, Int), Matrix)], @@ -151,10 +150,8 @@ class BlockMatrix( * rows are not required to have the given number of rows * @param colsPerBlock Number of columns that make up each block. 
The blocks forming the final * columns are not required to have the given number of columns - * - * @since 1.3.0 - * */ + @Since("1.3.0") def this( blocks: RDD[((Int, Int), Matrix)], rowsPerBlock: Int, @@ -162,20 +159,13 @@ class BlockMatrix( this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) } - /** - * @since 1.3.0 - * */ - + @Since("1.3.0") override def numRows(): Long = { if (nRows <= 0L) estimateDim() nRows } - /** - * - * @since 1.3.0 - */ - + @Since("1.3.0") override def numCols(): Long = { if (nCols <= 0L) estimateDim() nCols @@ -206,8 +196,8 @@ class BlockMatrix( /** * Validates the block matrix info against the matrix data (`blocks`) and throws an exception if * any error is found. - * @since 1.3.0 */ + @Since("1.3.0") def validate(): Unit = { logDebug("Validating BlockMatrix...") // check if the matrix is larger than the claimed dimensions @@ -243,25 +233,22 @@ class BlockMatrix( logDebug("BlockMatrix is valid!") } - /** Caches the underlying RDD. - * @since 1.3.0 - * */ + /** Caches the underlying RDD. */ + @Since("1.3.0") def cache(): this.type = { blocks.cache() this } - /** Persists the underlying RDD with the specified storage level. - * @since 1.3.0 - * */ + /** Persists the underlying RDD with the specified storage level. */ + @Since("1.3.0") def persist(storageLevel: StorageLevel): this.type = { blocks.persist(storageLevel) this } - /** Converts to CoordinateMatrix. - * @since 1.3.0 - * */ + /** Converts to CoordinateMatrix. */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => val rowStart = blockRowIndex.toLong * rowsPerBlock @@ -275,9 +262,8 @@ class BlockMatrix( new CoordinateMatrix(entryRDD, numRows(), numCols()) } - /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. - * @since 1.3.0 - * */ + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.3.0") def toIndexedRowMatrix(): IndexedRowMatrix = { require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + s"numCols: ${numCols()}") @@ -285,9 +271,8 @@ class BlockMatrix( toCoordinateMatrix().toIndexedRowMatrix() } - /** Collect the distributed matrix on the driver as a `DenseMatrix`. - * @since 1.3.0 - * */ + /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + @Since("1.3.0") def toLocalMatrix(): Matrix = { require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + s"Int.MaxValue. Currently numRows: ${numRows()}") @@ -312,11 +297,11 @@ class BlockMatrix( new DenseMatrix(m, n, values) } - /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the - * same underlying data. Is a lazy operation. - * @since 1.3.0 - * - * */ + /** + * Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. + */ + @Since("1.3.0") def transpose: BlockMatrix = { val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => ((blockColIndex, blockRowIndex), mat.transpose) @@ -330,13 +315,14 @@ class BlockMatrix( new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) } - /** Adds two block matrices together. The matrices must have the same size and matching - * `rowsPerBlock` and `colsPerBlock` values. 
If one of the blocks that are being added are - * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even - * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will - * also be a [[DenseMatrix]]. - * @since 1.3.0 - */ + /** + * Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + @Since("1.3.0") def add(other: BlockMatrix): BlockMatrix = { require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") @@ -364,14 +350,14 @@ class BlockMatrix( } } - /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` - * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains - * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output - * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause - * some performance issues until support for multiplying two sparse matrices is added. - * - * @since 1.3.0 - */ + /** + * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + */ + @Since("1.3.0") def multiply(other: BlockMatrix): BlockMatrix = { require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 2b751e45dd76c..4bb27ec840902 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} @@ -29,8 +29,8 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} * @param i row index * @param j column index * @param value value of the entry - * @since 1.0.0 */ +@Since("1.0.0") @Experimental case class MatrixEntry(i: Long, j: Long, value: Double) @@ -43,22 +43,20 @@ case class MatrixEntry(i: Long, j: Long, value: Double) * be determined by the max row index plus one. * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the max column index plus one. 
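// --- Illustrative usage sketch (not part of the patch) ---
// Exercises the BlockMatrix operations annotated above (validate, add,
// multiply, transpose, toLocalMatrix). Assumes an existing SparkContext `sc`;
// the entries are placeholders. The BlockMatrix is built via a CoordinateMatrix.
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

val blockEntries = sc.parallelize(Seq(
  MatrixEntry(0, 0, 1.0), MatrixEntry(1, 1, 2.0), MatrixEntry(2, 0, 3.0)))
val blockMat = new CoordinateMatrix(blockEntries).toBlockMatrix(2, 2).cache()
blockMat.validate()

val sum     = blockMat.add(blockMat)                  // same size and block sizes required
val product = blockMat.multiply(blockMat.transpose)   // colsPerBlock must equal other's rowsPerBlock
println(product.toLocalMatrix())
// --- end sketch ---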
- * @since 1.0.0 */ +@Since("1.0.0") @Experimental class CoordinateMatrix( val entries: RDD[MatrixEntry], private var nRows: Long, private var nCols: Long) extends DistributedMatrix { - /** Alternative constructor leaving matrix dimensions to be determined automatically. - * @since 1.0.0 - * */ + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L) - /** Gets or computes the number of columns. - * @since 1.0.0 - * */ + /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0L) { computeSize() @@ -66,9 +64,8 @@ class CoordinateMatrix( nCols } - /** Gets or computes the number of rows. - * @since 1.0.0 - * */ + /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { computeSize() @@ -76,16 +73,14 @@ class CoordinateMatrix( nRows } - /** Transposes this CoordinateMatrix. - * @since 1.3.0 - * */ + /** Transposes this CoordinateMatrix. */ + @Since("1.3.0") def transpose(): CoordinateMatrix = { new CoordinateMatrix(entries.map(x => MatrixEntry(x.j, x.i, x.value)), numCols(), numRows()) } - /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. - * @since 1.0.0 - * */ + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + @Since("1.0.0") def toIndexedRowMatrix(): IndexedRowMatrix = { val nl = numCols() if (nl > Int.MaxValue) { @@ -104,15 +99,14 @@ class CoordinateMatrix( /** * Converts to RowMatrix, dropping row indices after grouping by row index. * The number of columns must be within the integer range. - * @since 1.0.0 */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { toIndexedRowMatrix().toRowMatrix() } - /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. - * @since 1.3.0 - * */ + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -124,8 +118,8 @@ class CoordinateMatrix( * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] - * @since 1.3.0 */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { require(rowsPerBlock > 0, s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index 98e90af84abac..e51327ebb7b58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -19,10 +19,12 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.annotation.Since + /** * Represents a distributively stored matrix backed by one or more RDDs. - * @since 1.0.0 */ +@Since("1.0.0") trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. 
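// A sketch of building a CoordinateMatrix from MatrixEntry records and using the conversions
// annotated above (illustrative, not part of the patch). Assumes a SparkContext `sc`.
import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry}

val entries = sc.parallelize(Seq(
  MatrixEntry(0L, 0L, 1.0), MatrixEntry(1L, 2L, 3.5), MatrixEntry(2L, 1L, 2.0)))
val coordMat = new CoordinateMatrix(entries)       // dimensions inferred from the max indices
val rowMat   = coordMat.toRowMatrix()              // drops row indices
val asBlocks = coordMat.toBlockMatrix(1024, 1024)  // SparseMatrix blocks of size 1024 x 1024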
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index a09f88ce28e58..6d2c05a47d049 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition @@ -27,8 +27,8 @@ import org.apache.spark.mllib.linalg.SingularValueDecomposition /** * :: Experimental :: * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental case class IndexedRow(index: Long, vector: Vector) @@ -42,23 +42,19 @@ case class IndexedRow(index: Long, vector: Vector) * be determined by the max row index plus one. * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental class IndexedRowMatrix( val rows: RDD[IndexedRow], private var nRows: Long, private var nCols: Int) extends DistributedMatrix { - /** Alternative constructor leaving matrix dimensions to be determined automatically. - * @since 1.0.0 - * */ + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0) - /** - * - * @since 1.0.0 - */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { // Calling `first` will throw an exception if `rows` is empty. @@ -67,10 +63,7 @@ class IndexedRowMatrix( nCols } - /** - * - * @since 1.0.0 - */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { // Reduce will throw an exception if `rows` is empty. @@ -82,15 +75,14 @@ class IndexedRowMatrix( /** * Drops row indices and converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. - * @since 1.0.0 */ + @Since("1.0.0") def toRowMatrix(): RowMatrix = { new RowMatrix(rows.map(_.vector), 0L, nCols) } - /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. - * @since 1.3.0 - * */ + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + @Since("1.3.0") def toBlockMatrix(): BlockMatrix = { toBlockMatrix(1024, 1024) } @@ -102,8 +94,8 @@ class IndexedRowMatrix( * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have * a smaller value. Must be an integer value greater than 0. * @return a [[BlockMatrix]] - * @since 1.3.0 */ + @Since("1.3.0") def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { // TODO: This implementation may be optimized toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) @@ -112,8 +104,8 @@ class IndexedRowMatrix( /** * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. - * @since 1.3.0 */ + @Since("1.3.0") def toCoordinateMatrix(): CoordinateMatrix = { val entries = rows.flatMap { row => val rowIndex = row.index @@ -149,8 +141,8 @@ class IndexedRowMatrix( * @param rCond the reciprocal condition number. 
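// A sketch of an IndexedRowMatrix built from IndexedRow records, plus the conversions annotated
// above (illustrative, not part of the patch). Assumes a SparkContext `sc`.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

val indexedRows = sc.parallelize(Seq(
  IndexedRow(0L, Vectors.dense(1.0, 2.0, 3.0)),
  IndexedRow(3L, Vectors.dense(4.0, 5.0, 6.0))))
val idxMat = new IndexedRowMatrix(indexedRows)   // nRows = max index + 1, nCols from the first row
val coords = idxMat.toCoordinateMatrix()
val noIdx  = idxMat.toRowMatrix()                // row indices dropped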
All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V) - * @since 1.0.0 */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -176,8 +168,8 @@ class IndexedRowMatrix( * * @param B a local matrix whose number of rows must match the number of columns of this matrix * @return an IndexedRowMatrix representing the product, which preserves partitioning - * @since 1.0.0 */ + @Since("1.0.0") def multiply(B: Matrix): IndexedRowMatrix = { val mat = toRowMatrix().multiply(B) val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => @@ -188,8 +180,8 @@ class IndexedRowMatrix( /** * Computes the Gramian matrix `A^T A`. - * @since 1.0.0 */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { toRowMatrix().computeGramianMatrix() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index b2e94f2dd6201..78036eba5c3e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -28,7 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD @@ -44,22 +44,20 @@ import org.apache.spark.storage.StorageLevel * be determined by the number of records in the RDD `rows`. * @param nCols number of columns. A non-positive value means unknown, and then the number of * columns will be determined by the size of the first row. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental class RowMatrix( val rows: RDD[Vector], private var nRows: Long, private var nCols: Int) extends DistributedMatrix with Logging { - /** Alternative constructor leaving matrix dimensions to be determined automatically. - * @since 1.0.0 - * */ + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + @Since("1.0.0") def this(rows: RDD[Vector]) = this(rows, 0L, 0) - /** Gets or computes the number of columns. - * @since 1.0.0 - * */ + /** Gets or computes the number of columns. */ + @Since("1.0.0") override def numCols(): Long = { if (nCols <= 0) { try { @@ -74,9 +72,8 @@ class RowMatrix( nCols } - /** Gets or computes the number of rows. - * @since 1.0.0 - * */ + /** Gets or computes the number of rows. */ + @Since("1.0.0") override def numRows(): Long = { if (nRows <= 0L) { nRows = rows.count() @@ -114,8 +111,8 @@ class RowMatrix( /** * Computes the Gramian matrix `A^T A`. - * @since 1.0.0 */ + @Since("1.0.0") def computeGramianMatrix(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -185,8 +182,8 @@ class RowMatrix( * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. - * @since 1.0.0 */ + @Since("1.0.0") def computeSVD( k: Int, computeU: Boolean = false, @@ -326,8 +323,8 @@ class RowMatrix( /** * Computes the covariance matrix, treating each row as an observation. 
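// A sketch of a truncated SVD via the computeSVD method annotated above (illustrative, not part of
// the patch). `rowMat` is assumed to be a RowMatrix (or IndexedRowMatrix), for example the one
// built in the CoordinateMatrix sketch earlier.
val svd = rowMat.computeSVD(2, computeU = true)
val U = svd.U   // distributed left singular vectors (null when computeU = false)
val s = svd.s   // singular values as a local Vector, in descending order
val V = svd.V   // local n x k matrix of right singular vectors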
* @return a local dense matrix of size n x n - * @since 1.0.0 */ + @Since("1.0.0") def computeCovariance(): Matrix = { val n = numCols().toInt checkNumColumns(n) @@ -380,8 +377,8 @@ class RowMatrix( * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components - * @since 1.0.0 */ + @Since("1.0.0") def computePrincipalComponents(k: Int): Matrix = { val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") @@ -399,8 +396,8 @@ class RowMatrix( /** * Computes column-wise summary statistics. - * @since 1.0.0 */ + @Since("1.0.0") def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), @@ -415,8 +412,8 @@ class RowMatrix( * @param B a local matrix whose number of rows must match the number of columns of this matrix * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product, * which preserves partitioning - * @since 1.0.0 */ + @Since("1.0.0") def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt val k = B.numCols @@ -448,8 +445,8 @@ class RowMatrix( * * @return An n x n sparse upper-triangular matrix of cosine similarities between * columns of this matrix. - * @since 1.2.0 */ + @Since("1.2.0") def columnSimilarities(): CoordinateMatrix = { columnSimilarities(0.0) } @@ -492,8 +489,8 @@ class RowMatrix( * with the cost vs estimate quality trade-off described above. * @return An n x n sparse upper-triangular matrix of cosine similarities * between columns of this matrix. - * @since 1.2.0 */ + @Since("1.2.0") def columnSimilarities(threshold: Double): CoordinateMatrix = { require(threshold >= 0, s"Threshold cannot be negative: $threshold") @@ -671,9 +668,7 @@ class RowMatrix( } } -/** - * @since 1.0.0 - */ +@Since("1.0.0") @Experimental object RowMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 56c549ef99cb7..b27ef1b949e2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.recommendation import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.recommendation.{ALS => NewALS} import org.apache.spark.rdd.RDD @@ -26,8 +26,8 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. - * @since 0.8.0 */ +@Since("0.8.0") case class Rating(user: Int, product: Int, rating: Double) /** @@ -255,8 +255,8 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. 
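// A sketch of the RowMatrix statistics methods annotated above (illustrative, not part of the
// patch), again assuming `rowMat` is a RowMatrix with a handful of numeric columns.
val cov     = rowMat.computeCovariance()             // local n x n dense matrix
val pc      = rowMat.computePrincipalComponents(2)   // n x 2 matrix of principal components
val colStat = rowMat.computeColumnSummaryStatistics()
println(s"means=${colStat.mean} variances=${colStat.variance}")
val cosSims = rowMat.columnSimilarities(0.1)         // approximate cosine similarities (DIMSUM)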
- * @since 0.8.0 */ +@Since("0.8.0") object ALS { /** * Train a matrix factorization model given an RDD of ratings given by users to some products, @@ -271,8 +271,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed - * @since 0.9.1 */ + @Since("0.9.1") def train( ratings: RDD[Rating], rank: Int, @@ -296,8 +296,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @since 0.8.0 */ + @Since("0.8.0") def train( ratings: RDD[Rating], rank: Int, @@ -319,8 +319,8 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { train(ratings, rank, iterations, lambda, -1) @@ -336,8 +336,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.0 */ + @Since("0.8.0") def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { train(ratings, rank, iterations, 0.01, -1) @@ -357,8 +357,8 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -384,8 +384,8 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit( ratings: RDD[Rating], rank: Int, @@ -409,8 +409,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, lambda, -1, alpha) @@ -427,8 +427,8 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) - * @since 0.8.1 */ + @Since("0.8.1") def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 261ca9cef0c5b..ba4cfdcd9f1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -30,6 +30,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.mllib.linalg._ import 
org.apache.spark.mllib.rdd.MLPairRDDFunctions._ @@ -49,8 +50,8 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. - * @since 0.8.0 */ +@Since("0.8.0") class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], @@ -74,9 +75,8 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. - * @since 0.8.0 - */ + /** Predict the rating of one user for one product. */ + @Since("0.8.0") def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -114,8 +114,8 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. - * @since 0.9.0 */ + @Since("0.9.0") def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. // So if the usersProducts given for prediction contains only few products or @@ -146,8 +146,8 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. - * @since 1.2.0 */ + @Since("1.2.0") def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() } @@ -162,8 +162,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) .map(t => Rating(user, t._1, t._2)) @@ -179,8 +179,8 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. - * @since 1.1.0 */ + @Since("1.1.0") def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) .map(t => Rating(t._1, product, t._2)) @@ -199,8 +199,8 @@ class MatrixFactorizationModel( * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -212,8 +212,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. 
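// A sketch of training ALS and querying the resulting MatrixFactorizationModel with the methods
// annotated above (illustrative, not part of the patch). Assumes a SparkContext `sc`; the ratings
// are made up.
import org.apache.spark.mllib.recommendation.{ALS, Rating}

val ratings = sc.parallelize(Seq(
  Rating(1, 10, 4.0), Rating(1, 20, 1.0), Rating(2, 10, 5.0), Rating(2, 30, 3.0)))
val alsModel = ALS.train(ratings, 5, 10, 0.01)        // rank = 5, iterations = 10, lambda = 0.01
val oneScore = alsModel.predict(1, 30)                // single (user, product) prediction
val topTwo   = alsModel.recommendProducts(1, 2)       // Array[Rating], highest score first
val topAll   = alsModel.recommendProductsForUsers(2)  // RDD[(Int, Array[Rating])]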
Semantics of score is same as recommendProducts API - * @since 1.4.0 */ + @Since("1.4.0") def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { case (user, top) => @@ -230,8 +230,8 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API - * @since 1.4.0 */ + @Since("1.4.0") def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { case (product, top) => @@ -241,9 +241,7 @@ class MatrixFactorizationModel( } } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -326,8 +324,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. * @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 2980b94de35b0..509f6a2d169c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD @@ -35,8 +35,8 @@ import org.apache.spark.storage.StorageLevel * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. * - * @since 0.8.0 */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) extends Serializable { @@ -56,8 +56,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction * - * @since 1.0.0 */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. @@ -76,8 +76,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param testData array representing a single data point * @return Double prediction from the trained model * - * @since 1.0.0 */ + @Since("1.0.0") def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } @@ -95,8 +95,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). 
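// A sketch of round-tripping the factorization model through the save/load methods annotated above
// (illustrative, not part of the patch). Assumes the `alsModel` and `sc` from the previous sketch;
// the path is a placeholder and must not already exist.
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel

alsModel.save(sc, "target/tmp/alsModel")
val restored = MatrixFactorizationModel.load(sc, "target/tmp/alsModel")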
* This class should be extended with an Optimizer to create a new GLM. * - * @since 0.8.0 */ +@Since("0.8.0") @DeveloperApi abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { @@ -106,8 +106,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The optimizer to solve the problem. * - * @since 1.0.0 */ + @Since("1.0.0") def optimizer: Optimizer /** Whether to add intercept (default: false). */ @@ -143,8 +143,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * The dimension of training features. * - * @since 1.4.0 */ + @Since("1.4.0") def getNumFeatures: Int = this.numFeatures /** @@ -168,16 +168,16 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Get if the algorithm uses addIntercept * - * @since 1.4.0 */ + @Since("1.4.0") def isAddIntercept: Boolean = this.addIntercept /** * Set if the algorithm should add an intercept. Default false. * We set the default to false because adding the intercept will cause memory allocation. * - * @since 0.8.0 */ + @Since("0.8.0") def setIntercept(addIntercept: Boolean): this.type = { this.addIntercept = addIntercept this @@ -186,8 +186,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Set if the algorithm should validate data before training. Default true. * - * @since 0.8.0 */ + @Since("0.8.0") def setValidateData(validateData: Boolean): this.type = { this.validateData = validateData this @@ -197,8 +197,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. * - * @since 0.8.0 */ + @Since("0.8.0") def run(input: RDD[LabeledPoint]): M = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() @@ -231,8 +231,8 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. * - * @since 1.0.0 */ + @Since("1.0.0") def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { if (numFeatures < 0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 8995591d9e8ce..31ca7c2f207d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -29,7 +29,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -47,8 +47,8 @@ import org.apache.spark.sql.SQLContext * Results of isotonic regression and therefore monotone. * @param isotonic indicates whether this is isotonic or antitonic. * - * @since 1.3.0 */ +@Since("1.3.0") @Experimental class IsotonicRegressionModel ( val boundaries: Array[Double], @@ -64,8 +64,8 @@ class IsotonicRegressionModel ( /** * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. 
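// A sketch of the GeneralizedLinearAlgorithm knobs annotated above, exercised through one of its
// concrete subclasses (illustrative, not part of the patch). LinearRegressionWithSGD and its
// no-arg constructor come from elsewhere in mllib; assumes `trainingPoints` is an RDD[LabeledPoint].
import org.apache.spark.mllib.regression.LinearRegressionWithSGD

val lr = new LinearRegressionWithSGD()
lr.setIntercept(true)                 // also fit an intercept term
lr.setValidateData(true)              // check the input before training (the default)
lr.optimizer.setNumIterations(200)    // the exposed optimizer drives the actual fitting
val glmModel = lr.run(trainingPoints)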
* - * @since 1.4.0 */ + @Since("1.4.0") def this(boundaries: java.lang.Iterable[Double], predictions: java.lang.Iterable[Double], isotonic: java.lang.Boolean) = { @@ -90,8 +90,8 @@ class IsotonicRegressionModel ( * @param testData Features to be labeled. * @return Predicted labels. * - * @since 1.3.0 */ + @Since("1.3.0") def predict(testData: RDD[Double]): RDD[Double] = { testData.map(predict) } @@ -103,8 +103,8 @@ class IsotonicRegressionModel ( * @param testData Features to be labeled. * @return Predicted labels. * - * @since 1.3.0 */ + @Since("1.3.0") def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) } @@ -125,8 +125,8 @@ class IsotonicRegressionModel ( * as piecewise linear function and interpolated value is returned. In case there are * multiple values with the same boundary then the same rules as in 2) are used. * - * @since 1.3.0 */ + @Since("1.3.0") def predict(testData: Double): Double = { def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { @@ -160,9 +160,7 @@ class IsotonicRegressionModel ( /** A convenient method for boundaries called by the Python API. */ private[mllib] def predictionVector: Vector = Vectors.dense(predictions) - /** - * @since 1.4.0 - */ + @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) } @@ -170,9 +168,7 @@ class IsotonicRegressionModel ( override protected def formatVersion: String = "1.0" } -/** - * @since 1.4.0 - */ +@Since("1.4.0") object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { import org.apache.spark.mllib.util.Loader._ @@ -219,8 +215,8 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } /** - * @since 1.4.0 */ + @Since("1.4.0") override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 8b51011eeb297..f7fe1b7b21fca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -29,8 +30,8 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. * - * @since 0.8.0 */ +@Since("0.8.0") @BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { @@ -41,15 +42,15 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. * - * @since 1.1.0 */ +@Since("1.1.0") object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. 
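// A sketch of evaluating an IsotonicRegressionModel with the predict methods annotated above
// (illustrative, not part of the patch). The boundaries/predictions are made up; assumes `sc`.
import org.apache.spark.mllib.regression.IsotonicRegressionModel

val isoModel = new IsotonicRegressionModel(
  Array(0.0, 1.0, 3.0),   // boundaries, sorted ascending
  Array(1.0, 2.0, 4.0),   // predictions at those boundaries
  true)                   // isotonic (non-decreasing)
isoModel.predict(0.5)                                  // linear interpolation between boundaries
isoModel.predict(sc.parallelize(Seq(0.5, 2.0, 5.0)))   // RDD[Double] in, RDD[Double] out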
* - * @since 1.1.0 */ + @Since("1.1.0") def parse(s: String): LabeledPoint = { if (s.startsWith("(")) { NumericParser.parse(s) match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 03eb589b05a0e..556411a366bd2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -31,8 +32,8 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. * - * @since 0.8.0 */ +@Since("0.8.0") class LassoModel ( override val weights: Vector, override val intercept: Double) @@ -46,9 +47,7 @@ class LassoModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -56,14 +55,10 @@ class LassoModel ( override protected def formatVersion: String = "1.0" } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object LassoModel extends Loader[LassoModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): LassoModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -118,8 +113,8 @@ class LassoWithSGD private ( /** * Top-level methods for calling Lasso. * - * @since 0.8.0 */ +@Since("0.8.0") object LassoWithSGD { /** @@ -137,8 +132,8 @@ object LassoWithSGD { * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. * - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -162,8 +157,8 @@ object LassoWithSGD { * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -185,8 +180,8 @@ object LassoWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -205,8 +200,8 @@ object LassoWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LassoModel which has the weights and offset from training. 
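// A sketch of the LabeledPoint representation and its string parser annotated above (illustrative,
// not part of the patch).
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val point = LabeledPoint(1.0, Vectors.dense(0.5, -1.2, 3.0))
val again = LabeledPoint.parse(point.toString)   // "(1.0,[0.5,-1.2,3.0])" parses back losslessly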
* - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index fb5c220daaedb..00ab06e3ba738 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -31,8 +32,8 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. * - * @since 0.8.0 */ +@Since("0.8.0") class LinearRegressionModel ( override val weights: Vector, override val intercept: Double) @@ -46,9 +47,7 @@ class LinearRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -56,14 +55,10 @@ class LinearRegressionModel ( override protected def formatVersion: String = "1.0" } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object LinearRegressionModel extends Loader[LinearRegressionModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): LinearRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -117,8 +112,8 @@ class LinearRegressionWithSGD private[mllib] ( /** * Top-level methods for calling LinearRegression. * - * @since 0.8.0 */ +@Since("0.8.0") object LinearRegressionWithSGD { /** @@ -135,8 +130,8 @@ object LinearRegressionWithSGD { * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. * - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -158,8 +153,8 @@ object LinearRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -179,8 +174,8 @@ object LinearRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -198,8 +193,8 @@ object LinearRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LinearRegressionModel which has the weights and offset from training. 
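// A sketch of the top-level train helpers annotated above (illustrative, not part of the patch).
// Assumes `trainingPoints` is an RDD[LabeledPoint] with continuous labels.
import org.apache.spark.mllib.regression.{LassoWithSGD, LinearRegressionWithSGD}

val lassoModel = LassoWithSGD.train(trainingPoints, 100, 1.0, 0.01, 1.0)
  // numIterations = 100, stepSize = 1.0, regParam = 0.01, miniBatchFraction = 1.0
val lrModel = LinearRegressionWithSGD.train(trainingPoints, 100)
val firstPrediction = lrModel.predict(trainingPoints.first().features)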
* - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): LinearRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index b097fd38fdd82..0e72d6591ce83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -19,14 +19,12 @@ package org.apache.spark.mllib.regression import org.json4s.{DefaultFormats, JValue} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD -/** - * @since 0.8.0 - */ +@Since("0.8.0") @Experimental trait RegressionModel extends Serializable { /** @@ -35,8 +33,8 @@ trait RegressionModel extends Serializable { * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction * - * @since 1.0.0 */ + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -45,8 +43,8 @@ trait RegressionModel extends Serializable { * @param testData array representing a single data point * @return Double prediction from the trained model * - * @since 1.0.0 */ + @Since("1.0.0") def predict(testData: Vector): Double /** @@ -54,8 +52,8 @@ trait RegressionModel extends Serializable { * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction * - * @since 1.0.0 */ + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 5bced6b4b7b53..21a791d98b2cb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable @@ -32,8 +33,8 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. 
* - * @since 0.8.0 */ +@Since("0.8.0") class RidgeRegressionModel ( override val weights: Vector, override val intercept: Double) @@ -47,9 +48,7 @@ class RidgeRegressionModel ( weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { GLMRegressionModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, weights, intercept) } @@ -57,14 +56,10 @@ class RidgeRegressionModel ( override protected def formatVersion: String = "1.0" } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object RidgeRegressionModel extends Loader[RidgeRegressionModel] { - /** - * @since 1.3.0 - */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): RidgeRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -120,8 +115,8 @@ class RidgeRegressionWithSGD private ( /** * Top-level methods for calling RidgeRegression. * - * @since 0.8.0 */ +@Since("0.8.0") object RidgeRegressionWithSGD { /** @@ -138,8 +133,8 @@ object RidgeRegressionWithSGD { * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -162,8 +157,8 @@ object RidgeRegressionWithSGD { * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -184,8 +179,8 @@ object RidgeRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int, @@ -203,8 +198,8 @@ object RidgeRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a RidgeRegressionModel which has the weights and offset from training. * - * @since 0.8.0 */ + @Since("0.8.0") def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index a2ab95c474765..cd3ed8a1549db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} @@ -54,8 +54,8 @@ import org.apache.spark.streaming.dstream.DStream * the model using each of the different sources, in sequence. * * - * @since 1.1.0 */ +@Since("1.1.0") @DeveloperApi abstract class StreamingLinearAlgorithm[ M <: GeneralizedLinearModel, @@ -70,8 +70,8 @@ abstract class StreamingLinearAlgorithm[ /** * Return the latest model. 
* - * @since 1.1.0 */ + @Since("1.1.0") def latestModel(): M = { model.get } @@ -84,8 +84,8 @@ abstract class StreamingLinearAlgorithm[ * * @param data DStream containing labeled data * - * @since 1.3.0 */ + @Since("1.3.0") def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") @@ -106,8 +106,8 @@ abstract class StreamingLinearAlgorithm[ /** * Java-friendly version of `trainOn`. * - * @since 1.3.0 */ + @Since("1.3.0") def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) /** @@ -116,8 +116,8 @@ abstract class StreamingLinearAlgorithm[ * @param data DStream containing feature vectors * @return DStream containing predictions * - * @since 1.1.0 */ + @Since("1.1.0") def predictOn(data: DStream[Vector]): DStream[Double] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction.") @@ -128,8 +128,8 @@ abstract class StreamingLinearAlgorithm[ /** * Java-friendly version of `predictOn`. * - * @since 1.1.0 */ + @Since("1.1.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } @@ -140,8 +140,8 @@ abstract class StreamingLinearAlgorithm[ * @tparam K key type * @return DStream containing the input keys and the predictions as values * - * @since 1.1.0 */ + @Since("1.1.0") def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting prediction") @@ -153,8 +153,8 @@ abstract class StreamingLinearAlgorithm[ /** * Java-friendly version of `predictOnValues`. * - * @since 1.3.0 */ + @Since("1.3.0") def predictOnValues[K](data: JavaPairDStream[K, Vector]): JavaPairDStream[K, java.lang.Double] = { implicit val tag = fakeClassTag[K] JavaPairDStream.fromPairDStream( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 93a6753efd4d9..4a856f7f3434a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD @@ -37,8 +37,8 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} - * @since 1.4.0 */ +@Since("1.4.0") @Experimental class KernelDensity extends Serializable { @@ -52,8 +52,8 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). - * @since 1.4.0 */ + @Since("1.4.0") def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") this.bandwidth = bandwidth @@ -62,8 +62,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. - * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: RDD[Double]): this.type = { this.sample = sample this @@ -71,8 +71,8 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). 
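// A sketch of the StreamingLinearAlgorithm contract annotated above, via its concrete subclass
// StreamingLinearRegressionWithSGD, which is defined elsewhere in mllib (illustrative, not part of
// the patch). Assumes `trainingStream` is a DStream[LabeledPoint] and `testStream` a
// DStream[(Long, Vector)] built from a StreamingContext.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD

val streamingLR = new StreamingLinearRegressionWithSGD()
  .setInitialWeights(Vectors.zeros(3))         // the model must be initialized before trainOn/predictOn
streamingLR.trainOn(trainingStream)            // latestModel() is refreshed as each batch arrives
val scored = streamingLR.predictOnValues(testStream)   // DStream[(Long, Double)]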
- * @since 1.4.0 */ + @Since("1.4.0") def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] this @@ -80,8 +80,8 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. - * @since 1.4.0 */ + @Since("1.4.0") def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample val bandwidth = this.bandwidth diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 64e4be0ebb97e..51b713e263e0c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector} /** @@ -33,8 +33,8 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. - * @since 1.1.0 */ +@Since("1.1.0") @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -53,8 +53,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. - * @since 1.1.0 */ + @Since("1.1.0") def add(sample: Vector): this.type = { if (n == 0) { require(sample.size > 0, s"Vector should have dimension larger than zero.") @@ -109,8 +109,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. - * @since 1.1.0 */ + @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + @@ -155,8 +155,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * Sample mean of each dimension. * - * @since 1.1.0 */ + @Since("1.1.0") override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -172,8 +172,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * Sample variance of each dimension. * - * @since 1.1.0 */ + @Since("1.1.0") override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -199,15 +199,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * Sample size. * - * @since 1.1.0 */ + @Since("1.1.0") override def count: Long = totalCnt /** * Number of nonzero elements in each dimension. * - * @since 1.1.0 */ + @Since("1.1.0") override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -217,8 +217,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * Maximum value of each dimension. 
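// A sketch of kernel density estimation with the API annotated above, mirroring its Scaladoc
// example (illustrative, not part of the patch). Assumes a SparkContext `sc`; the sample is made up.
import org.apache.spark.mllib.stat.KernelDensity

val sample = sc.parallelize(Seq(1.0, 1.5, 2.0, 2.5, 8.0))
val kd = new KernelDensity()
  .setSample(sample)
  .setBandwidth(0.5)
val densities = kd.estimate(Array(0.0, 2.0, 8.0))   // estimated PDF values at the query points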
* - * @since 1.1.0 */ + @Since("1.1.0") override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -233,8 +233,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * Minimum value of each dimension. * - * @since 1.1.0 */ + @Since("1.1.0") override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -249,8 +249,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * L2 (Euclidian) norm of each dimension. * - * @since 1.2.0 */ + @Since("1.2.0") override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -268,8 +268,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S /** * L1 norm of each dimension. * - * @since 1.2.0 */ + @Since("1.2.0") override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 3bb49f12289e1..39a16fb743d64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -17,59 +17,60 @@ package org.apache.spark.mllib.stat +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. - * @since 1.0.0 */ +@Since("1.0.0") trait MultivariateStatisticalSummary { /** * Sample mean vector. - * @since 1.0.0 */ + @Since("1.0.0") def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. - * @since 1.0.0 */ + @Since("1.0.0") def variance: Vector /** * Sample size. - * @since 1.0.0 */ + @Since("1.0.0") def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. - * @since 1.0.0 */ + @Since("1.0.0") def numNonzeros: Vector /** * Maximum value of each column. - * @since 1.0.0 */ + @Since("1.0.0") def max: Vector /** * Minimum value of each column. - * @since 1.0.0 */ + @Since("1.0.0") def min: Vector /** * Euclidean magnitude of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL2: Vector /** * L1 norm of each column - * @since 1.2.0 */ + @Since("1.2.0") def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index ef8d78607048f..84d64a5bfb38e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} @@ -32,8 +32,8 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. - * @since 1.1.0 */ +@Since("1.1.0") @Experimental object Statistics { @@ -42,8 +42,8 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. 
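// A sketch of aggregating column statistics with the MultivariateOnlineSummarizer annotated above,
// using the same treeAggregate pattern as computeColumnSummaryStatistics (illustrative, not part
// of the patch). Assumes `vectors` is an RDD[Vector].
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

val summary = vectors.treeAggregate(new MultivariateOnlineSummarizer)(
  (agg, v: Vector) => agg.add(v),     // add() skips explicit zeros, so cost is O(nnz) per vector
  (a, b) => a.merge(b))
println(s"mean=${summary.mean} variance=${summary.variance} count=${summary.count}")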
* @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. - * @since 1.1.0 */ + @Since("1.1.0") def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() } @@ -54,8 +54,8 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) /** @@ -71,8 +71,8 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. - * @since 1.1.0 */ + @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) /** @@ -85,14 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s - * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -109,14 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. - * @since 1.1.0 */ + @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) /** * Java-friendly version of [[corr()]] - * @since 1.4.1 */ + @Since("1.4.1") def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -133,8 +133,8 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) } @@ -148,8 +148,8 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) /** @@ -159,8 +159,8 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. - * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) /** @@ -172,13 +172,14 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. 
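// A sketch of the Statistics entry points annotated above (illustrative, not part of the patch).
// Assumes a SparkContext `sc`; the observations are made up.
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

val observations = sc.parallelize(Seq(
  Vectors.dense(1.0, 10.0), Vectors.dense(2.0, 20.0), Vectors.dense(3.0, 33.0)))
val colSummary = Statistics.colStats(observations)        // column-wise mean, variance, count, ...
val pearson    = Statistics.corr(observations)            // Pearson correlation matrix
val spearman   = Statistics.corr(observations, "spearman")
val gof        = Statistics.chiSqTest(Vectors.dense(10.0, 20.0, 30.0))  // goodness of fit vs uniform
println(gof.pValue)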
- * @since 1.1.0 */ + @Since("1.1.0") def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } /** Java-friendly version of [[chiSqTest()]] */ + @Since("1.5.0") def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = chiSqTest(data.rdd) /** @@ -194,6 +195,7 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. */ + @Since("1.5.0") def kolmogorovSmirnovTest(data: RDD[Double], cdf: Double => Double) : KolmogorovSmirnovTestResult = { KolmogorovSmirnovTest.testOneSample(data, cdf) @@ -210,6 +212,7 @@ object Statistics { * @return [[org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult]] object containing test * statistic, p-value, and null hypothesis. */ + @Since("1.5.0") @varargs def kolmogorovSmirnovTest(data: RDD[Double], distName: String, params: Double*) : KolmogorovSmirnovTestResult = { @@ -217,6 +220,7 @@ object Statistics { } /** Java-friendly version of [[kolmogorovSmirnovTest()]] */ + @Since("1.5.0") @varargs def kolmogorovSmirnovTest( data: JavaDoubleRDD, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 9aa7763d7890d..bd4d81390bfae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.distribution import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} -import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} import org.apache.spark.mllib.util.MLUtils @@ -32,8 +32,8 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution - * @since 1.3.0 */ +@Since("1.3.0") @DeveloperApi class MultivariateGaussian ( val mu: Vector, @@ -62,15 +62,15 @@ class MultivariateGaussian ( private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants /** Returns density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } /** Returns the log-density of this multivariate Gaussian at given point, x - * @since 1.3.0 */ + @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index e5200b86fddd4..972841015d4f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo @@ -43,8 +43,8 @@ import org.apache.spark.util.random.XORShiftRandom * @param strategy The configuration parameters for the tree algorithm 
which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { @@ -54,8 +54,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) @@ -64,9 +64,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } -/** - * @since 1.0.0 - */ +@Since("1.0.0") object DecisionTree extends Serializable with Logging { /** @@ -84,8 +82,8 @@ object DecisionTree extends Serializable with Logging { * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction - * @since 1.0.0 - */ + */ + @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { new DecisionTree(strategy).run(input) } @@ -106,8 +104,8 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @return DecisionTreeModel that can be used for prediction - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -134,8 +132,8 @@ object DecisionTree extends Serializable with Logging { * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. * @param numClasses number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -168,8 +166,8 @@ object DecisionTree extends Serializable with Logging { * E.g., an entry (n -> k) indicates that feature n is categorical * with k categories indexed from 0: {0, 1, ..., k-1}. 
* @return DecisionTreeModel that can be used for prediction - * @since 1.0.0 */ + @Since("1.0.0") def train( input: RDD[LabeledPoint], algo: Algo, @@ -201,8 +199,8 @@ object DecisionTree extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction - * @since 1.1.0 */ + @Since("1.1.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -217,8 +215,8 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * @since 1.1.0 */ + @Since("1.1.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -247,8 +245,8 @@ object DecisionTree extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 32) * @return DecisionTreeModel that can be used for prediction - * @since 1.1.0 */ + @Since("1.1.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -261,8 +259,8 @@ object DecisionTree extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * @since 1.1.0 */ + @Since("1.1.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 143617098637a..e750408600c33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint @@ -48,8 +48,8 @@ import org.apache.spark.storage.StorageLevel * for other loss functions. * * @param boostingStrategy Parameters for the gradient boosting algorithm. - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { @@ -58,8 +58,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * Method to train a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return a gradient boosted trees model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { @@ -76,8 +76,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. 
- * @since 1.2.0 */ + @Since("1.2.0") def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } @@ -91,8 +91,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) * E.g., these two datasets could be created from an original dataset * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction - * @since 1.4.0 */ + @Since("1.4.0") def runWithValidation( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -115,8 +115,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. - * @since 1.4.0 */ + @Since("1.4.0") def runWithValidation( input: JavaRDD[LabeledPoint], validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { @@ -124,9 +124,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) } } -/** - * @since 1.2.0 - */ +@Since("1.2.0") object GradientBoostedTrees extends Logging { /** @@ -137,8 +135,8 @@ object GradientBoostedTrees extends Logging { * For regression, labels are real numbers. * @param boostingStrategy Configuration options for the boosting algorithm. * @return a gradient boosted trees model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def train( input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { @@ -147,8 +145,8 @@ object GradientBoostedTrees extends Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] - * @since 1.2.0 */ + @Since("1.2.0") def train( input: JavaRDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 9f3230656acc5..63a902f3eb51e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy @@ -260,9 +260,7 @@ private class RandomForest ( } -/** - * @since 1.2.0 - */ +@Since("1.2.0") object RandomForest extends Serializable with Logging { /** @@ -279,8 +277,8 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt". * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, @@ -317,8 +315,8 @@ object RandomForest extends Serializable with Logging { * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. 
* @return a random forest model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def trainClassifier( input: RDD[LabeledPoint], numClasses: Int, @@ -337,8 +335,8 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] - * @since 1.2.0 */ + @Since("1.2.0") def trainClassifier( input: JavaRDD[LabeledPoint], numClasses: Int, @@ -368,8 +366,8 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "onethird". * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, @@ -405,8 +403,8 @@ object RandomForest extends Serializable with Logging { * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. * @return a random forest model that can be used for prediction - * @since 1.2.0 */ + @Since("1.2.0") def trainRegressor( input: RDD[LabeledPoint], categoricalFeaturesInfo: Map[Int, Int], @@ -424,8 +422,8 @@ object RandomForest extends Serializable with Logging { /** * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] - * @since 1.2.0 */ + @Since("1.2.0") def trainRegressor( input: JavaRDD[LabeledPoint], categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], @@ -442,8 +440,8 @@ object RandomForest extends Serializable with Logging { /** * List of supported feature subset sampling strategies. - * @since 1.2.0 */ + @Since("1.2.0") val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "sqrt", "log2", "onethird") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index d9a49aa71fcfb..8301ad160836b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to select the algorithm for the decision tree - * @since 1.0.0 */ +@Since("1.0.0") @Experimental object Algo extends Enumeration { type Algo = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 88e5f57e9ab32..7c569981977b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} @@ -38,8 +38,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * validation input between two iterations is less than the validationTol * then stop. Ignored when * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. 
- * @since 1.2.0 */ +@Since("1.2.0") @Experimental case class BoostingStrategy( // Required boosting parameters @@ -71,9 +71,7 @@ case class BoostingStrategy( } } -/** - * @since 1.2.0 - */ +@Since("1.2.0") @Experimental object BoostingStrategy { @@ -81,8 +79,8 @@ object BoostingStrategy { * Returns default configuration for the boosting algorithm * @param algo Learning goal. Supported: "Classification" or "Regression" * @return Configuration for boosting algorithm - * @since 1.2.0 */ + @Since("1.2.0") def defaultParams(algo: String): BoostingStrategy = { defaultParams(Algo.fromString(algo)) } @@ -93,8 +91,8 @@ object BoostingStrategy { * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @return Configuration for boosting algorithm - * @since 1.3.0 */ + @Since("1.3.0") def defaultParams(algo: Algo): BoostingStrategy = { val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index 0684cafa486bd..bb7c7ee4f964f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum to describe whether a feature is "continuous" or "categorical" - * @since 1.0.0 */ +@Since("1.0.0") @Experimental object FeatureType extends Enumeration { type FeatureType = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 2daa63c4d2771..904e42deebb5f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: * Enum for selecting the quantile calculation strategy - * @since 1.0.0 */ +@Since("1.0.0") @Experimental object QuantileStrategy extends Enumeration { type QuantileStrategy = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7ae25a88bf500..a58f01ba8544e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.configuration import scala.beans.BeanProperty import scala.collection.JavaConverters._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -66,8 +66,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * E.g. 
10 means that the cache will get checkpointed every 10 updates. If * the checkpoint directory is not set in * [[org.apache.spark.SparkContext]], this setting is ignored. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental class Strategy ( @BeanProperty var algo: Algo, @@ -85,23 +85,23 @@ class Strategy ( @BeanProperty var checkpointInterval: Int = 10) extends Serializable { /** - * @since 1.2.0 */ + @Since("1.2.0") def isMulticlassClassification: Boolean = { algo == Classification && numClasses > 2 } /** - * @since 1.2.0 */ + @Since("1.2.0") def isMulticlassWithCategoricalFeatures: Boolean = { isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] - * @since 1.1.0 */ + @Since("1.1.0") def this( algo: Algo, impurity: Impurity, @@ -115,8 +115,8 @@ class Strategy ( /** * Sets Algorithm using a String. - * @since 1.2.0 */ + @Since("1.2.0") def setAlgo(algo: String): Unit = algo match { case "Classification" => setAlgo(Classification) case "Regression" => setAlgo(Regression) @@ -124,8 +124,8 @@ class Strategy ( /** * Sets categoricalFeaturesInfo using a Java Map. - * @since 1.2.0 */ + @Since("1.2.0") def setCategoricalFeaturesInfo( categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { this.categoricalFeaturesInfo = @@ -174,8 +174,8 @@ class Strategy ( /** * Returns a shallow copy of this instance. - * @since 1.2.0 */ + @Since("1.2.0") def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, @@ -183,17 +183,15 @@ class Strategy ( } } -/** - * @since 1.2.0 - */ +@Since("1.2.0") @Experimental object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo "Classification" or "Regression" - * @since 1.2.0 */ + @Since("1.2.0") def defaultStrategy(algo: String): Strategy = { defaultStrategy(Algo.fromString(algo)) } @@ -201,8 +199,8 @@ object Strategy { /** * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] * @param algo Algo.Classification or Algo.Regression - * @since 1.3.0 */ + @Since("1.3.0") def defaultStrategy(algo: Algo): Strategy = algo match { case Algo.Classification => new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 0b6c7266dee05..73df6b054a8ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during * binary classification. 
- * @since 1.0.0 */ +@Since("1.0.0") @Experimental object Entropy extends Impurity { @@ -36,8 +36,8 @@ object Entropy extends Impurity { * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 - * @since 1.1.0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -64,8 +64,8 @@ object Entropy extends Impurity { * @param sum sum of labels * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 - * @since 1.0.0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") @@ -73,8 +73,8 @@ object Entropy extends Impurity { /** * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. - * @since 1.1.0 */ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 3b0be428833ae..f21845b21a802 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,15 +17,15 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating the * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] * during binary classification. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental object Gini extends Impurity { @@ -35,8 +35,8 @@ object Gini extends Impurity { * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 - * @since 1.1.0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = { if (totalCount == 0) { @@ -60,8 +60,8 @@ object Gini extends Impurity { * @param sum sum of labels * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 - * @since 1.0.0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") @@ -69,8 +69,8 @@ object Gini extends Impurity { /** * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. 
- * @since 1.1.0 */ + @Since("1.1.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index dd297400058d2..4637dcceea7f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: @@ -25,8 +25,8 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} * This trait is used for * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] * (b) calculating impurity values from sufficient statistics. - * @since 1.0.0 */ +@Since("1.0.0") @Experimental trait Impurity extends Serializable { @@ -36,8 +36,8 @@ trait Impurity extends Serializable { * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 - * @since 1.1.0 */ + @Since("1.1.0") @DeveloperApi def calculate(counts: Array[Double], totalCount: Double): Double @@ -48,8 +48,8 @@ trait Impurity extends Serializable { * @param sum sum of labels * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 - * @since 1.0.0 */ + @Since("1.0.0") @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index adbe05811f286..a74197278d6f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} /** * :: Experimental :: * Class for calculating variance during regression - * @since 1.0.0 */ +@Since("1.0.0") @Experimental object Variance extends Impurity { @@ -33,8 +33,8 @@ object Variance extends Impurity { * @param counts Array[Double] with counts for each label * @param totalCount sum of counts for all labels * @return information value, or 0 if totalCount = 0 - * @since 1.1.0 */ + @Since("1.1.0") @DeveloperApi override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") @@ -46,8 +46,8 @@ object Variance extends Impurity { * @param sum sum of labels * @param sumSquares summation of squares of the labels * @return information value, or 0 if count = 0 - * @since 1.0.0 */ + @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { if (count == 0) { @@ -60,8 +60,8 @@ object Variance extends Impurity { /** * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. 
- * @since 1.0.0 */ + @Since("1.0.0") def instance: this.type = this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index c6e3d0d824dd7..bab7b8c6cadf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -29,8 +29,8 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * The absolute (L1) error is defined as: * |y - F(x)| * where y is the label and F(x) is the model prediction for features x. - * @since 1.2.0 */ +@Since("1.2.0") @DeveloperApi object AbsoluteError extends Loss { @@ -41,8 +41,8 @@ object AbsoluteError extends Loss { * @param prediction Predicted label. * @param label True label. * @return Loss gradient - * @since 1.2.0 */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { if (label - prediction < 0) 1.0 else -1.0 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index eee58445a1ec1..b2b4594712f0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils @@ -31,8 +31,8 @@ import org.apache.spark.mllib.util.MLUtils * The log loss is defined as: * 2 log(1 + exp(-2 y F(x))) * where y is a label in {-1, 1} and F(x) is the model prediction for features x. - * @since 1.2.0 */ +@Since("1.2.0") @DeveloperApi object LogLoss extends Loss { @@ -43,8 +43,8 @@ object LogLoss extends Loss { * @param prediction Predicted label. * @param label True label. * @return Loss gradient - * @since 1.2.0 */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 7c9fb924645c8..687cde325ffed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD @@ -26,8 +26,8 @@ import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. - * @since 1.2.0 */ +@Since("1.2.0") @DeveloperApi trait Loss extends Serializable { @@ -36,8 +36,8 @@ trait Loss extends Serializable { * @param prediction Predicted feature * @param label true label. 
* @return Loss gradient. - * @since 1.2.0 */ + @Since("1.2.0") def gradient(prediction: Double, label: Double): Double /** @@ -47,8 +47,8 @@ trait Loss extends Serializable { * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data - * @since 1.2.0 */ + @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map(point => computeError(model.predict(point.features), point.label)).mean() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala index 47dc94cde7e03..2b112fbe12202 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -17,14 +17,12 @@ package org.apache.spark.mllib.tree.loss -/** - * @since 1.2.0 - */ +import org.apache.spark.annotation.Since + +@Since("1.2.0") object Losses { - /** - * @since 1.2.0 - */ + @Since("1.2.0") def fromString(name: String): Loss = name match { case "leastSquaresError" => SquaredError case "leastAbsoluteError" => AbsoluteError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index ff8903d6956bd..3f7d3d38be16c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.loss -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel @@ -29,8 +29,8 @@ import org.apache.spark.mllib.tree.model.TreeEnsembleModel * The squared (L2) error is defined as: * (y - F(x))**2 * where y is the label and F(x) is the model prediction for features x. - * @since 1.2.0 */ +@Since("1.2.0") @DeveloperApi object SquaredError extends Loss { @@ -41,8 +41,8 @@ object SquaredError extends Loss { * @param prediction Predicted label. * @param label True label. * @return Loss gradient - * @since 1.2.0 */ + @Since("1.2.0") override def gradient(prediction: Double, label: Double): Double = { - 2.0 * (label - prediction) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0f386a26601ce..3eefd135f7836 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} @@ -40,8 +40,8 @@ import org.apache.spark.util.Utils * This model stores the decision tree structure and parameters. 
* @param topNode root node * @param algo algorithm type -- classification or regression - * @since 1.0.0 */ +@Since("1.0.0") @Experimental class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { @@ -50,8 +50,8 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * * @param features array representing a single data point * @return Double prediction from the trained model - * @since 1.0.0 */ + @Since("1.0.0") def predict(features: Vector): Double = { topNode.predict(features) } @@ -61,8 +61,8 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * * @param features RDD representing data points to be predicted * @return RDD of predictions for each of the given data points - * @since 1.0.0 */ + @Since("1.0.0") def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) } @@ -72,16 +72,16 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * * @param features JavaRDD representing data points to be predicted * @return JavaRDD of predictions for each of the given data points - * @since 1.2.0 */ + @Since("1.2.0") def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { predict(features.rdd) } /** * Get number of nodes in tree, including leaf nodes. - * @since 1.1.0 */ + @Since("1.1.0") def numNodes: Int = { 1 + topNode.numDescendants } @@ -89,8 +89,8 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Get depth of tree. * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. - * @since 1.1.0 */ + @Since("1.1.0") def depth: Int = { topNode.subtreeDepth } @@ -119,8 +119,8 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } @@ -128,9 +128,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable override protected def formatVersion: String = DecisionTreeModel.formatVersion } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { private[spark] def formatVersion: String = "1.0" @@ -317,8 +315,8 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. 
* @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): DecisionTreeModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index 23f0363639120..091a0462c204f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** @@ -29,8 +29,8 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator * @param rightImpurity right node impurity * @param leftPredict left node predict * @param rightPredict right node predict - * @since 1.0.0 */ +@Since("1.0.0") @DeveloperApi class InformationGainStats( val gain: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index aca3350c2e535..8c54c55107233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vector @@ -38,8 +38,8 @@ import org.apache.spark.mllib.linalg.Vector * @param leftNode left child * @param rightNode right child * @param stats information gain stats - * @since 1.0.0 */ +@Since("1.0.0") @DeveloperApi class Node ( val id: Int, @@ -59,8 +59,8 @@ class Node ( /** * build the left node and right nodes if not leaf * @param nodes array of nodes - * @since 1.0.0 */ + @Since("1.0.0") @deprecated("build should no longer be used since trees are constructed on-the-fly in training", "1.2.0") def build(nodes: Array[Node]): Unit = { @@ -81,8 +81,8 @@ class Node ( * predict value if node is not leaf * @param features feature value * @return predicted value - * @since 1.1.0 */ + @Since("1.1.0") def predict(features: Vector) : Double = { if (isLeaf) { predict.predict diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index be819b59e7048..965784051ede5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) - * @since 1.2.0 */ +@Since("1.2.0") @DeveloperApi class Predict( val predict: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 18d40530aee1e..45db83ae3a1f3 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @@ -30,8 +30,8 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous * @param categories Split left if categorical feature value is in this set, else right. - * @since 1.0.0 */ +@Since("1.0.0") @DeveloperApi case class Split( feature: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 0c629b12a84df..19571447a2c56 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -25,7 +25,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils * * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), @@ -60,8 +60,8 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, RandomForestModel.SaveLoadV1_0.thisClassName) @@ -70,9 +70,7 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis override protected def formatVersion: String = RandomForestModel.formatVersion } -/** - * @since 1.3.0 - */ +@Since("1.3.0") object RandomForestModel extends Loader[RandomForestModel] { private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion @@ -82,8 +80,8 @@ object RandomForestModel extends Loader[RandomForestModel] { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. 
* @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -114,8 +112,8 @@ object RandomForestModel extends Loader[RandomForestModel] { * @param algo algorithm for the ensemble model, either Classification or Regression * @param trees tree ensembles * @param treeWeights tree ensemble weights - * @since 1.2.0 */ +@Since("1.2.0") @Experimental class GradientBoostedTreesModel( override val algo: Algo, @@ -130,8 +128,8 @@ class GradientBoostedTreesModel( * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. - * @since 1.3.0 */ + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) @@ -143,8 +141,8 @@ class GradientBoostedTreesModel( * @param loss evaluation metric. * @return an array with index i having the losses or errors for the ensemble * containing the first i+1 trees - * @since 1.4.0 */ + @Since("1.4.0") def evaluateEachIteration( data: RDD[LabeledPoint], loss: Loss): Array[Double] = { @@ -186,8 +184,8 @@ class GradientBoostedTreesModel( } /** - * @since 1.3.0 */ +@Since("1.3.0") object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { /** @@ -199,8 +197,8 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param loss: evaluation metric. * @return a RDD with each element being a zip of the prediction and error * corresponding to every sample. - * @since 1.4.0 */ + @Since("1.4.0") def computeInitialPredictionAndError( data: RDD[LabeledPoint], initTreeWeight: Double, @@ -223,8 +221,8 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param loss: evaluation metric. * @return a RDD with each element being a zip of the prediction and error * corresponding to each sample. - * @since 1.4.0 */ + @Since("1.4.0") def updatePredictionError( data: RDD[LabeledPoint], predictionAndError: RDD[(Double, Double)], @@ -248,8 +246,8 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param sc Spark context used for loading model files. * @param path Path specifying the directory to which the model was saved. * @return Model instance - * @since 1.3.0 */ + @Since("1.3.0") override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/package.scala index f520b3a1b7c72..bcaacc1b1f191 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/package.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/package.scala @@ -24,7 +24,6 @@ package org.apache.spark.mllib * - information loss calculation with entropy and Gini for classification and * variance for regression, * - both continuous and categorical features. 
- * @since 1.0.0 */ package object tree { } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 11ed23176fc12..4940974bf4f41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD @@ -64,8 +64,8 @@ object MLUtils { * feature dimensions. * @param minPartitions min number of partitions * @return labeled data stored as an RDD[LabeledPoint] - * @since 1.0.0 */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, @@ -115,9 +115,7 @@ object MLUtils { // Convenient methods for `loadLibSVMFile`. - /** - * @since 1.0.0 - */ + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -130,17 +128,15 @@ object MLUtils { /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. - * @since 1.0.0 */ + @Since("1.0.0") def loadLibSVMFile( sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures, sc.defaultMinPartitions) - /** - * @since 1.0.0 - */ + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -149,9 +145,7 @@ object MLUtils { numFeatures: Int): RDD[LabeledPoint] = loadLibSVMFile(sc, path, numFeatures) - /** - * @since 1.0.0 - */ + @Since("1.0.0") @deprecated("use method without multiclass argument, which no longer has effect", "1.1.0") def loadLibSVMFile( sc: SparkContext, @@ -162,8 +156,8 @@ object MLUtils { /** * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], with number of * features determined automatically and the default number of partitions. - * @since 1.0.0 */ + @Since("1.0.0") def loadLibSVMFile(sc: SparkContext, path: String): RDD[LabeledPoint] = loadLibSVMFile(sc, path, -1) @@ -193,15 +187,15 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return vectors stored as an RDD[Vector] - * @since 1.1.0 */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = sc.textFile(path, minPartitions).map(Vectors.parse) /** * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. - * @since 1.1.0 */ + @Since("1.1.0") def loadVectors(sc: SparkContext, path: String): RDD[Vector] = sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) @@ -211,16 +205,16 @@ object MLUtils { * @param path file or directory path in any Hadoop-supported file system URI * @param minPartitions min number of partitions * @return labeled points stored as an RDD[LabeledPoint] - * @since 1.1.0 */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = sc.textFile(path, minPartitions).map(LabeledPoint.parse) /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of * partitions. 
- * @since 1.1.0 */ + @Since("1.1.0") def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) @@ -236,8 +230,8 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. - * @since 1.0.0 */ + @Since("1.0.0") @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => @@ -258,8 +252,8 @@ object MLUtils { * * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. - * @since 1.0.0 */ + @Since("1.0.0") @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) @@ -271,8 +265,8 @@ object MLUtils { * Return a k element array of pairs of RDDs with the first element of each pair * containing the training data, a complement of the validation data and the second * element, the validation data, containing a unique 1/kth of the data. Where k=numFolds. - * @since 1.0.0 */ + @Since("1.0.0") @Experimental def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat @@ -287,8 +281,8 @@ object MLUtils { /** * Returns a new vector with `1.0` (bias) appended to the input vector. - * @since 1.0.0 */ + @Since("1.0.0") def appendBias(vector: Vector): Vector = { vector match { case dv: DenseVector => From e3355090d4030daffed5efb0959bf1d724c13c13 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 21 Aug 2015 14:30:00 -0700 Subject: [PATCH 036/802] [SPARK-10143] [SQL] Use parquet's block size (row group size) setting as the min split size if necessary. https://issues.apache.org/jira/browse/SPARK-10143 With this PR, we will set min split size to parquet's block size (row group size) set in the conf if the min split size is smaller. So, we can avoid have too many tasks and even useless tasks for reading parquet data. I tested it locally. The table I have has 343MB and it is in my local FS. Because I did not set any min/max split size, the default split size was 32MB and the map stage had 11 tasks. But there were only three tasks that actually read data. With my PR, there were only three tasks in the map stage. Here is the difference. Without this PR: ![image](https://cloud.githubusercontent.com/assets/2072857/9399179/8587dba6-4765-11e5-9189-7ebba52a2b6d.png) With this PR: ![image](https://cloud.githubusercontent.com/assets/2072857/9399185/a4735d74-4765-11e5-8848-1f1e361a6b4b.png) Even if the block size setting does match the actual block size of parquet file, I think it is still generally good to use parquet's block size setting if min split size is smaller than this block size. Tested it on a cluster using ``` val count = sqlContext.table("""store_sales""").groupBy().count().queryExecution.executedPlan(3).execute().count ``` Basically, it reads 0 column of table `store_sales`. My table has 1824 parquet files with size from 80MB to 280MB (1 to 3 row group sizes). Without this patch, in a 16 worker cluster, the job had 5023 tasks and spent 102s. With this patch, the job had 2893 tasks and spent 64s. 
It is still not as good as using one mapper per file (1824 tasks and 42s), but it is much better than our master. Author: Yin Huai Closes #8346 from yhuai/parquetMinSplit. --- .../datasources/parquet/ParquetRelation.scala | 41 ++++++++++++++++++- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 68169d48ac57c..bbf682aec0f9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.util.{Failure, Try} import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -281,12 +282,18 @@ private[sql] class ParquetRelation( val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + // Parquet row group size. We will use this value as the value for + // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value + // of these flags are smaller than the parquet row group size. + val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) + // Create the function to set variable Parquet confs at both driver and executor side. val initLocalJobFuncOpt = ParquetRelation.initializeLocalJobFunc( requiredColumns, filters, dataSchema, + parquetBlockSize, useMetadataCache, parquetFilterPushDown, assumeBinaryIsString, @@ -294,7 +301,8 @@ private[sql] class ParquetRelation( followParquetFormatSpec) _ // Create the function to set input paths at the driver side. - val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + val setInputPaths = + ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( @@ -482,11 +490,35 @@ private[sql] object ParquetRelation extends Logging { // internally. private[sql] val METASTORE_SCHEMA = "metastoreSchema" + /** + * If parquet's block size (row group size) setting is larger than the min split size, + * we use parquet's block size setting as the min split size. Otherwise, we will create + * tasks processing nothing (because a split does not cover the starting point of a + * parquet block). See https://issues.apache.org/jira/browse/SPARK-10143 for more information. + */ + private def overrideMinSplitSize(parquetBlockSize: Long, conf: Configuration): Unit = { + val minSplitSize = + math.max( + conf.getLong("mapred.min.split.size", 0L), + conf.getLong("mapreduce.input.fileinputformat.split.minsize", 0L)) + if (parquetBlockSize > minSplitSize) { + val message = + s"Parquet's block size (row group size) is larger than " + + s"mapred.min.split.size/mapreduce.input.fileinputformat.split.minsize. Setting " + + s"mapred.min.split.size and mapreduce.input.fileinputformat.split.minsize to " + + s"$parquetBlockSize." 
+ logDebug(message) + conf.set("mapred.min.split.size", parquetBlockSize.toString) + conf.set("mapreduce.input.fileinputformat.split.minsize", parquetBlockSize.toString) + } + } + /** This closure sets various Parquet configurations at both driver side and executor side. */ private[parquet] def initializeLocalJobFunc( requiredColumns: Array[String], filters: Array[Filter], dataSchema: StructType, + parquetBlockSize: Long, useMetadataCache: Boolean, parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, @@ -522,16 +554,21 @@ private[sql] object ParquetRelation extends Logging { conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + + overrideMinSplitSize(parquetBlockSize, conf) } /** This closure sets input paths at the driver side. */ private[parquet] def initializeDriverSideJobFunc( - inputFiles: Array[FileStatus])(job: Job): Unit = { + inputFiles: Array[FileStatus], + parquetBlockSize: Long)(job: Job): Unit = { // We side the input paths at the driver side. logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") if (inputFiles.nonEmpty) { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } + + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( From f01c4220d2b791f470fa6596ffe11baa51517fbe Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 21 Aug 2015 16:28:00 -0700 Subject: [PATCH 037/802] [SPARK-10163] [ML] Allow single-category features for GBT models Removed categorical feature info validation since no longer needed This is needed to make the ML user guide examples work (in another current PR). CC: mengxr Author: Joseph K. Bradley Closes #8367 from jkbradley/gbt-single-cat. --- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index a58f01ba8544e..b74e3f1f46523 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -158,11 +158,6 @@ class Strategy ( s" Valid values are integers >= 0.") require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + s" Valid values are integers >= 2.") - categoricalFeaturesInfo.foreach { case (feature, arity) => - require(arity >= 2, - s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + - s" feature $feature has $arity categories. The number of categories should be >= 2.") - } require(minInstancesPerNode >= 1, s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") require(maxMemoryInMB <= 10240, From 630a994e6a9785d1704f8e7fb604f32f5dea24f8 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 21 Aug 2015 16:30:12 -0700 Subject: [PATCH 038/802] [SPARK-9893] User guide with Java test suite for VectorSlicer Add user guide for `VectorSlicer`, with Java test suite and Python version VectorSlicer. Note that Python version does not support selecting by names now. Author: Xusen Yin Closes #8267 from yinxusen/SPARK-9893. 
--- docs/ml-features.md | 133 ++++++++++++++++++ .../ml/feature/JavaVectorSlicerSuite.java | 85 +++++++++++ 2 files changed, 218 insertions(+) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 6309db97be4d0..642a4b4c53183 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1477,6 +1477,139 @@ print(output.select("features", "clicked").first()) +# Feature Selectors + +## VectorSlicer + +`VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a +sub-array of the original features. It is useful for extracting features from a vector column. + +`VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column +whose values are selected via those indices. There are two types of indices, + + 1. Integer indices that represents the indices into the vector, `setIndices()`; + + 2. String indices that represents the names of features into the vector, `setNames()`. + *This requires the vector column to have an `AttributeGroup` since the implementation matches on + the name field of an `Attribute`.* + +Specification by integer and string are both acceptable. Moreover, you can use integer index and +string name simultaneously. At least one feature must be selected. Duplicate features are not +allowed, so there can be no overlap between selected indices and names. Note that if names of +features are selected, an exception will be threw out when encountering with empty input attributes. + +The output vector will order features with the selected indices first (in the order given), +followed by the selected names (in the order given). + +**Examples** + +Suppose that we have a DataFrame with the column `userFeatures`: + +~~~ + userFeatures +------------------ + [0.0, 10.0, 0.5] +~~~ + +`userFeatures` is a vector column that contains three user features. Assuming that the first column +of `userFeatures` are all zeros, so we want to remove it and only the last two columns are selected. +The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` then produces a new vector +column named `features`: + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] +~~~ + +Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +`["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. + +~~~ + userFeatures | features +------------------|----------------------------- + [0.0, 10.0, 0.5] | [10.0, 0.5] + ["f1", "f2", "f3"] | ["f2", "f3"] +~~~ + +
+<div data-lang="scala" markdown="1">
+ +[`VectorSlicer`](api/scala/index.html#org.apache.spark.ml.feature.VectorSlicer) takes an input +column name with specified indices or names and an output column name. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.feature.VectorSlicer +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, Row, SQLContext} + +val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3, 0.0) +) + +val defaultAttr = NumericAttribute.defaultAttr +val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) +val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) + +val dataRDD = sc.parallelize(data).map(Row.apply) +val dataset = sqlContext.createDataFrame(dataRDD, StructType(attrGroup.toStructField())) + +val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") + +slicer.setIndices(1).setNames("f3") +// or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) + +val output = slicer.transform(dataset) +println(output.select("userFeatures", "features").first()) +{% endhighlight %} +
+
+<div data-lang="java" markdown="1">
+ +[`VectorSlicer`](api/java/org/apache/spark/ml/feature/VectorSlicer.html) takes an input column name +with specified indices or names and an output column name. + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") +}; +AttributeGroup group = new AttributeGroup("userFeatures", attrs); + +JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) +)); + +DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + +VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + +vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); +// or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) + +DataFrame output = vectorSlicer.transform(dataset); + +System.out.println(output.select("userFeatures", "features").first()); +{% endhighlight %} +
+</div>
+ ## RFormula `RFormula` selects columns specified by an [R model formula](https://stat.ethz.ch/R-manual/R-devel/library/stats/html/formula.html). It produces a vector column of features and a double column of labels. Like when formulas are used in R for linear regression, string input columns will be one-hot encoded, and numeric columns will be cast to doubles. If not already present in the DataFrame, the output label column will be created from the specified response variable in the formula. diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java new file mode 100644 index 0000000000000..56988b9fb29cb --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.attribute.Attribute; +import org.apache.spark.ml.attribute.AttributeGroup; +import org.apache.spark.ml.attribute.NumericAttribute; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructType; + + +public class JavaVectorSlicerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void vectorSlice() { + Attribute[] attrs = new Attribute[]{ + NumericAttribute.defaultAttr().withName("f1"), + NumericAttribute.defaultAttr().withName("f2"), + NumericAttribute.defaultAttr().withName("f3") + }; + AttributeGroup group = new AttributeGroup("userFeatures", attrs); + + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), + RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) + )); + + DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + + VectorSlicer vectorSlicer = new VectorSlicer() + .setInputCol("userFeatures").setOutputCol("features"); + + vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); + + DataFrame output = vectorSlicer.transform(dataset); + + for (Row r : 
output.select("userFeatures", "features").take(2)) { + Vector features = r.getAs(1); + Assert.assertEquals(features.size(), 2); + } + } +} From 46fcb9e0dbb2b28110f68a3d9f6c0c47bfd197b1 Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Sat, 22 Aug 2015 02:38:10 -0700 Subject: [PATCH 039/802] Update programming-guide.md Update `lineLengths.persist();` to `lineLengths.persist(StorageLevel.MEMORY_ONLY());` because `JavaRDD#persist` needs a parameter of `StorageLevel`. Author: Keiji Yoshida Closes #8372 from yosssi/patch-1. --- docs/programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 982c5eabe652b..4cf83bb392636 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -549,7 +549,7 @@ returning only its answer to the driver program. If we also wanted to use `lineLengths` again later, we could add: {% highlight java %} -lineLengths.persist(); +lineLengths.persist(StorageLevel.MEMORY_ONLY()); {% endhighlight %} before the `reduce`, which would cause `lineLengths` to be saved in memory after the first time it is computed. From 90cb9f05655a25b95b8f9fe81da14e5b9c8bcf44 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sat, 22 Aug 2015 10:16:35 -0700 Subject: [PATCH 040/802] [SPARK-9401] [SQL] Fully implement code generation for ConcatWs This PR adds full codegen support for ConcatWs, is a substitute of #7782 JIRA: https://issues.apache.org/jira/browse/SPARK-9401 cc davies Author: Yijie Shen Closes #8353 from yjshen/concatws. --- .../expressions/stringExpressions.scala | 42 +++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b60d318534a41..48d02bb534501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -72,7 +72,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ case class ConcatWs(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, s"$prettyName requires at least one argument.") @@ -114,8 +114,44 @@ case class ConcatWs(children: Seq[Expression]) boolean ${ev.isNull} = ${ev.primitive} == null; """ } else { - // Contains a mix of strings and arrays. Fall back to interpreted mode for now. - super.genCode(ctx, ev) + val array = ctx.freshName("array") + val varargNum = ctx.freshName("varargNum") + val idxInVararg = ctx.freshName("idxInVararg") + + val evals = children.map(_.gen(ctx)) + val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => + child.dataType match { + case StringType => + ("", // we count all the StringType arguments num at once below. + s"$array[$idxInVararg ++] = ${eval.isNull} ? 
(UTF8String) null : ${eval.primitive};") + case _: ArrayType => + val size = ctx.freshName("n") + (s""" + if (!${eval.isNull}) { + $varargNum += ${eval.primitive}.numElements(); + } + """, + s""" + if (!${eval.isNull}) { + final int $size = ${eval.primitive}.numElements(); + for (int j = 0; j < $size; j ++) { + $array[$idxInVararg ++] = ${ctx.getValue(eval.primitive, StringType, "j")}; + } + } + """) + } + }.unzip + + evals.map(_.code).mkString("\n") + + s""" + int $varargNum = ${children.count(_.dataType == StringType) - 1}; + int $idxInVararg = 0; + ${varargCount.mkString("\n")} + UTF8String[] $array = new UTF8String[$varargNum]; + ${varargBuild.mkString("\n")} + UTF8String ${ev.primitive} = UTF8String.concatWs(${evals.head.primitive}, $array); + boolean ${ev.isNull} = ${ev.primitive} == null; + """ } } } From 623c675fde7a3a39957a62c7af26a54f4b01f8ce Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Sun, 23 Aug 2015 11:04:29 +0100 Subject: [PATCH 041/802] Update streaming-programming-guide.md Update `See the Scala example` to `See the Java example`. Author: Keiji Yoshida Closes #8376 from yosssi/patch-1. --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c59d936b43c88..118ced298f4b0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1702,7 +1702,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example +context and set up the DStreams. See the Java example [JavaRecoverableNetworkWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). This example appends the word counts of network data into a file. From c6df5f66d9a8b9760f2cd46fcd930f977650c9c5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 23 Aug 2015 17:41:49 -0700 Subject: [PATCH 042/802] [SPARK-10148] [STREAMING] Display active and inactive receiver numbers in Streaming page Added the active and inactive receiver numbers in the summary section of Streaming page. screen shot 2015-08-21 at 2 08 54 pm Author: zsxwing Closes #8351 from zsxwing/receiver-number. 
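Returning briefly to the ConcatWs patch above, a small usage sketch of the mixed string/array case that previously fell back to interpreted evaluation. This is illustrative only; it assumes a `SQLContext` named `sqlContext`, and the column names are made up.

```scala
import org.apache.spark.sql.functions.{col, concat_ws}

// One array<string> column mixed with plain string columns: the case the new
// codegen path handles instead of falling back to CodegenFallback.
val df = sqlContext.createDataFrame(Seq(
  ("a", "b", Seq("c", "d")),
  ("x", "y", Seq.empty[String])
)).toDF("s1", "s2", "arr")

// concat_ws skips nulls and flattens the array elements into the joined result.
df.select(concat_ws("-", col("s1"), col("s2"), col("arr")).as("joined")).show()
// expected rows: "a-b-c-d" and "x-y"
```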
--- .../spark/streaming/ui/StreamingJobProgressListener.scala | 8 ++++++++ .../org/apache/spark/streaming/ui/StreamingPage.scala | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index b77c555c68b8b..78aeb004e18b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -148,6 +148,14 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) receiverInfos.size } + def numActiveReceivers: Int = synchronized { + receiverInfos.count(_._2.active) + } + + def numInactiveReceivers: Int = { + ssc.graph.getReceiverInputStreams().size - numActiveReceivers + } + def numTotalCompletedBatches: Long = synchronized { totalCompletedBatches } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 87af902428ec8..96d943e75d272 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -303,6 +303,7 @@ private[ui] class StreamingPage(parent: StreamingTab) val numCompletedBatches = listener.retainedCompletedBatches.size val numActiveBatches = batchTimes.length - numCompletedBatches + val numReceivers = listener.numInactiveReceivers + listener.numActiveReceivers val table = // scalastyle:off @@ -330,6 +331,11 @@ private[ui] class StreamingPage(parent: StreamingTab) } } + { + if (numReceivers > 0) { +
<div>Receivers: {listener.numActiveReceivers} / {numReceivers} active</div>
+                }
+              }
<div>Avg: {eventRateForAllStreams.formattedAvg} events/sec</div>
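A self-contained sketch of the counting rule behind the new summary line; `ReceiverRecord` here is a simplified stand-in for the listener's internal receiver bookkeeping, not a Spark class.

```scala
// Simplified model: each receiver input stream is expected to contribute one receiver,
// and a receiver counts as active only while it reports active = true.
final case class ReceiverRecord(streamId: Int, active: Boolean)

def receiverSummary(records: Seq[ReceiverRecord], expectedReceivers: Int): String = {
  val numActive = records.count(_.active)
  s"Receivers: $numActive / $expectedReceivers active"
}

// Example: three receiver streams, one receiver currently down.
println(receiverSummary(
  Seq(ReceiverRecord(0, active = true), ReceiverRecord(1, active = false), ReceiverRecord(2, active = true)),
  expectedReceivers = 3))
// Receivers: 2 / 3 active
```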
From b963c19a803c5a26c9b65655d40ca6621acf8bd4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 23 Aug 2015 18:34:07 -0700 Subject: [PATCH 043/802] [SPARK-10164] [MLLIB] Fixed GMM distributed decomposition bug GaussianMixture now distributes matrix decompositions for certain problem sizes. Distributed computation actually fails, but this was not tested in unit tests. This PR adds a unit test which checks this. It failed previously but works with this fix. CC: mengxr Author: Joseph K. Bradley Closes #8370 from jkbradley/gmm-fix. --- .../mllib/clustering/GaussianMixture.scala | 22 +++++++++++++------ .../clustering/GaussianMixtureSuite.scala | 22 +++++++++++++++++-- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index fcc9dfecac54f..daa947e81d44d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -169,9 +169,7 @@ class GaussianMixture private ( // Get length of the input vectors val d = breezeData.first().length - // Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when - // d > 25 except for when k is very small - val distributeGaussians = ((k - 1.0) / k) * d > 25 + val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d) // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise @@ -205,15 +203,15 @@ class GaussianMixture private ( // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum - if (distributeGaussians) { + if (shouldDistributeGaussians) { val numPartitions = math.min(k, 1024) val tuples = Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i))) val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) => updateWeightsAndGaussians(mean, sigma, weight, sumWeights) - }.collect.unzip - Array.copy(ws, 0, weights, 0, ws.length) - Array.copy(gs, 0, gaussians, 0, gs.length) + }.collect().unzip + Array.copy(ws.toArray, 0, weights, 0, ws.length) + Array.copy(gs.toArray, 0, gaussians, 0, gs.length) } else { var i = 0 while (i < k) { @@ -271,6 +269,16 @@ class GaussianMixture private ( } } +private[clustering] object GaussianMixture { + /** + * Heuristic to distribute the computation of the [[MultivariateGaussian]]s, approximately when + * d > 25 except for when k is very small. 
+ * @param k Number of topics + * @param d Number of features + */ + def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25 +} + // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index b636d02f786e6..a72723eb00daf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vectors, Matrices} +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -76,6 +76,20 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3) } + test("two clusters with distributed decompositions") { + val data = sc.parallelize(GaussianTestData.data2, 2) + + val k = 5 + val d = data.first().size + assert(GaussianMixture.shouldDistributeGaussians(k, d)) + + val gmm = new GaussianMixture() + .setK(k) + .run(data) + + assert(gmm.k === k) + } + test("single cluster with sparse data") { val data = sc.parallelize(Array( Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)), @@ -116,7 +130,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { val sparseGMM = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) - .run(data) + .run(sparseData) assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3) assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3) @@ -168,5 +182,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734) ) + val data2: Array[Vector] = Array.tabulate(25){ i: Int => + Vectors.dense(Array.tabulate(50)(i + _.toDouble)) + } + } } From 053d94fcf32268369b5a40837271f15d6af41aa4 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 23 Aug 2015 19:24:32 -0700 Subject: [PATCH 044/802] [SPARK-10142] [STREAMING] Made python checkpoint recovery handle non-local checkpoint paths and existing SparkContexts The current code only checks checkpoint files in local filesystem, and always tries to create a new Python SparkContext (even if one already exists). The solution is to do the following: 1. Use the same code path as Java to check whether a valid checkpoint exists 2. Create a new Python SparkContext only if there no active one. There is not test for the path as its hard to test with distributed filesystem paths in a local unit test. I am going to test it with a distributed file system manually to verify that this patch works. 
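For readers following along, the Scala-side recovery pattern that this change mirrors in Python looks roughly like the sketch below; the checkpoint directory, app name, and batch interval are placeholders, and the DStream setup is elided.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val checkpointDir = "hdfs:///tmp/recoverable-app"   // placeholder path

def createContext(): StreamingContext = {
  val conf = new SparkConf().setAppName("RecoverableApp")
  val ssc = new StreamingContext(conf, Seconds(1))
  // define the DStream graph here before enabling checkpointing
  ssc.checkpoint(checkpointDir)
  ssc
}

// Recreates the context from checkpoint data if it exists; otherwise calls createContext().
val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
ssc.start()
ssc.awaitTermination()
```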
Author: Tathagata Das Closes #8366 from tdas/SPARK-10142 and squashes the following commits: 3afa666 [Tathagata Das] Added tests 2dd4ae5 [Tathagata Das] Added the check to not create a context if one already exists 9bf151b [Tathagata Das] Made python checkpoint recovery use java to find the checkpoint files --- python/pyspark/streaming/context.py | 22 ++++++---- python/pyspark/streaming/tests.py | 43 ++++++++++++++++--- .../apache/spark/streaming/Checkpoint.scala | 9 ++++ 3 files changed, 58 insertions(+), 16 deletions(-) diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index e3ba70e4e5e88..4069d7a149986 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -150,26 +150,30 @@ def getOrCreate(cls, checkpointPath, setupFunc): @param checkpointPath: Checkpoint directory used in an earlier streaming program @param setupFunc: Function to create a new context and setup DStreams """ - # TODO: support checkpoint in HDFS - if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + cls._ensure_initialized() + gw = SparkContext._gateway + + # Check whether valid checkpoint information exists in the given path + if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - cls._ensure_initialized() - gw = SparkContext._gateway - try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise - jsc = jssc.sparkContext() - conf = SparkConf(_jconf=jsc.getConf()) - sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # If there is already an active instance of Python SparkContext use it, or create a new one + if not SparkContext._active_spark_context: + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + SparkContext(conf=conf, gateway=gw, jsc=jsc) + + sc = SparkContext._active_spark_context + # update ctx in serializer - SparkContext._active_spark_context = sc cls._transformerSerializer.ctx = sc return StreamingContext(sc, None, jssc) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 214d5be439003..510a4f2b3e472 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -603,6 +603,10 @@ def tearDownClass(): def tearDown(self): if self.ssc is not None: self.ssc.stop(True) + if self.sc is not None: + self.sc.stop() + if self.cpd is not None: + shutil.rmtree(self.cpd) def test_get_or_create_and_get_active_or_create(self): inputd = tempfile.mkdtemp() @@ -622,8 +626,12 @@ def setup(): self.setupCalled = True return ssc - cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = StreamingContext.getOrCreate(cpd, setup) + # Verify that getOrCreate() calls setup() in absence of checkpoint files + self.cpd = tempfile.mkdtemp("test_streaming_cps") + self.setupCalled = False + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.ssc.start() def check_output(n): @@ -660,31 +668,52 @@ def check_output(n): self.ssc.stop(True, True) time.sleep(1) self.setupCalled = False - self.ssc = StreamingContext.getOrCreate(cpd, setup) + self.ssc = StreamingContext.getOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) self.ssc.start() check_output(3) + # Verify that getOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + sc = SparkContext(SparkConf()) + self.setupCalled = False + self.ssc = 
StreamingContext.getOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == sc) + # Verify the getActiveOrCreate() recovers from checkpoint files self.ssc.stop(True, True) time.sleep(1) self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) self.assertFalse(self.setupCalled) self.ssc.start() check_output(4) # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc) + self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) + # Verify that getActiveOrCreate() uses existing SparkContext + self.ssc.stop(True, True) + time.sleep(1) + self.sc = SparkContext(SparkConf()) + self.setupCalled = False + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) + self.assertFalse(self.setupCalled) + self.assertTrue(self.ssc.sparkContext == sc) + # Verify that getActiveOrCreate() calls setup() in absence of checkpoint files self.ssc.stop(True, True) - shutil.rmtree(cpd) # delete checkpoint directory + shutil.rmtree(self.cpd) # delete checkpoint directory + time.sleep(1) self.setupCalled = False - self.ssc = StreamingContext.getActiveOrCreate(cpd, setup) + self.ssc = StreamingContext.getActiveOrCreate(self.cpd, setup) self.assertTrue(self.setupCalled) + + # Stop everything self.ssc.stop(True, True) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 6f6b449accc3c..cd5d960369c05 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -286,6 +286,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None. + */ + def read(checkpointDir: String): Option[Checkpoint] = { + read(checkpointDir, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = true) + } + /** * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint * files, then return None, else try to return the latest valid checkpoint object. If no From 4e0395ddb764d092b5b38447af49e196e590e0f0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 24 Aug 2015 12:38:01 -0700 Subject: [PATCH 045/802] [SPARK-10168] [STREAMING] Fix the issue that maven publishes wrong artifact jars This PR removed the `outputFile` configuration from pom.xml and updated `tests.py` to search jars for both sbt build and maven build. I ran ` mvn -Pkinesis-asl -DskipTests clean install` locally, and verified the jars in my local repository were correct. I also checked Python tests for maven build, and it passed all tests. 
Author: zsxwing Closes #8373 from zsxwing/SPARK-10168 and squashes the following commits: e0b5818 [zsxwing] Fix the sbt build c697627 [zsxwing] Add the jar pathes to the exception message be1d8a5 [zsxwing] Fix the issue that maven publishes wrong artifact jars --- external/flume-assembly/pom.xml | 1 - external/kafka-assembly/pom.xml | 1 - external/mqtt-assembly/pom.xml | 1 - extras/kinesis-asl-assembly/pom.xml | 1 - python/pyspark/streaming/tests.py | 47 ++++++++++++++++------------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index e05e4318969ce..561ed4babe5d0 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -115,7 +115,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar *:* diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 36342f37bb2ea..6f4e2a89e9af7 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -142,7 +142,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kafka-assembly-${project.version}.jar *:* diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index f3e3f93e7ed50..8412600633734 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -132,7 +132,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-mqtt-assembly-${project.version}.jar *:* diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 3ca538608f694..51af3e6f2225f 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -137,7 +137,6 @@ maven-shade-plugin false - ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar *:* diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 510a4f2b3e472..cfea95b0dec71 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1162,11 +1162,20 @@ def get_output(_, rdd): kinesisTestUtils.deleteDynamoDBTable(kinesisAppName) +# Search jar in the project dir using the jar name_prefix for both sbt build and maven build because +# the artifact jars are in different directories. +def search_jar(dir, name_prefix): + # We should ignore the following jars + ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar") + jars = (glob.glob(os.path.join(dir, "target/scala-*/" + name_prefix + "-*.jar")) + # sbt build + glob.glob(os.path.join(dir, "target/" + name_prefix + "_*.jar"))) # maven build + return [jar for jar in jars if not jar.endswith(ignored_jar_suffixes)] + + def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") - jars = glob.glob( - os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. 
" % kafka_assembly_dir) + @@ -1174,8 +1183,8 @@ def search_kafka_assembly_jar(): "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " "'build/mvn package' before running this test.") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " - "remove all but one") % kafka_assembly_dir) + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -1183,8 +1192,7 @@ def search_kafka_assembly_jar(): def search_flume_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") - jars = glob.glob( - os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + jars = search_jar(flume_assembly_dir, "spark-streaming-flume-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + @@ -1192,8 +1200,8 @@ def search_flume_assembly_jar(): "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " "'build/mvn package' before running this test.") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " - "remove all but one") % flume_assembly_dir) + raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -1201,8 +1209,7 @@ def search_flume_assembly_jar(): def search_mqtt_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly") - jars = glob.glob( - os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar")) + jars = search_jar(mqtt_assembly_dir, "spark-streaming-mqtt-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming MQTT assembly jar in %s. 
" % mqtt_assembly_dir) + @@ -1210,8 +1217,8 @@ def search_mqtt_assembly_jar(): "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " "'build/mvn package' before running this test") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " - "remove all but one") % mqtt_assembly_dir) + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -1227,8 +1234,8 @@ def search_mqtt_test_jar(): "You need to build Spark with " "'build/sbt assembly/assembly streaming-mqtt/test:assembly'") elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming MQTT test JARs in %s; please " - "remove all but one") % mqtt_test_dir) + raise Exception(("Found multiple Spark Streaming MQTT test JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -1236,14 +1243,12 @@ def search_mqtt_test_jar(): def search_kinesis_asl_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly") - jars = glob.glob( - os.path.join(kinesis_asl_assembly_dir, - "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar")) + jars = search_jar(kinesis_asl_assembly_dir, "spark-streaming-kinesis-asl-assembly") if not jars: return None elif len(jars) > 1: - raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please " - "remove all but one") % kinesis_asl_assembly_dir) + raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs: %s; please " + "remove all but one") % (", ".join(jars))) else: return jars[0] @@ -1269,8 +1274,8 @@ def search_kinesis_asl_assembly_jar(): mqtt_test_jar, kinesis_asl_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars - testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, - CheckpointTests, KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests] + testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, + KafkaStreamTests, FlumeStreamTests, FlumePollingStreamTests, MQTTStreamTests] if kinesis_jar_present is True: testcases.append(KinesisStreamTests) From 7478c8b66d6a2b1179f20c38b49e27e37b0caec3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 24 Aug 2015 12:40:09 -0700 Subject: [PATCH 046/802] [SPARK-9791] [PACKAGE] Change private class to private class to prevent unnecessary classes from showing up in the docs In addition, some random cleanup of import ordering Author: Tathagata Das Closes #8387 from tdas/SPARK-9791 and squashes the following commits: 67f3ee9 [Tathagata Das] Change private class to private[package] class to prevent them from showing up in the docs --- .../spark/streaming/flume/FlumeUtils.scala | 2 +- .../apache/spark/streaming/kafka/Broker.scala | 6 ++-- .../streaming/kafka/KafkaTestUtils.scala | 10 +++--- .../spark/streaming/kafka/KafkaUtils.scala | 36 +++++-------------- .../spark/streaming/kafka/OffsetRange.scala | 8 ----- .../spark/streaming/mqtt/MQTTUtils.scala | 6 ++-- .../spark/streaming/mqtt/MQTTTestUtils.scala | 2 +- .../streaming/kinesis/KinesisTestUtils.scala | 2 +- .../spark/streaming/util/WriteAheadLog.java | 2 ++ .../util/WriteAheadLogRecordHandle.java | 2 ++ .../receiver/ReceivedBlockHandler.scala | 2 +- .../streaming/scheduler/ReceiverTracker.scala | 2 +- .../apache/spark/streaming/ui/BatchPage.scala | 2 +- 13 files changed, 28 insertions(+), 54 deletions(-) 
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 095bfb0c73a9a..a65a9b921aafa 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -247,7 +247,7 @@ object FlumeUtils { * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and * function so that it can be easily instantiated and called from Python's FlumeUtils. */ -private class FlumeUtilsPythonHelper { +private[flume] class FlumeUtilsPythonHelper { def createStream( jssc: JavaStreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala index 5a74febb4bd46..9159051ba06e4 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala @@ -20,11 +20,9 @@ package org.apache.spark.streaming.kafka import org.apache.spark.annotation.Experimental /** - * :: Experimental :: - * Represent the host and port info for a Kafka broker. - * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID + * Represents the host and port info for a Kafka broker. + * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID. */ -@Experimental final class Broker private( /** Broker's hostname */ val host: String, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index b608b75952721..79a9db4291bef 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,9 +20,8 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.{Map => JMap} -import java.util.Properties import java.util.concurrent.TimeoutException +import java.util.{Map => JMap, Properties} import scala.annotation.tailrec import scala.language.postfixOps @@ -30,17 +29,16 @@ import scala.util.control.NonFatal import kafka.admin.AdminUtils import kafka.api.Request -import kafka.common.TopicAndPartition import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.{ZKStringSerializer, ZkUtils} -import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} -import org.apache.spark.{Logging, SparkConf} import org.apache.spark.streaming.Time import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. This has the functionality to set up @@ -48,7 +46,7 @@ import org.apache.spark.util.Utils * * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. 
*/ -private class KafkaTestUtils extends Logging { +private[kafka] class KafkaTestUtils extends Logging { // Zookeeper related configurations private val zkHost = "localhost" diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index f3b01bd60b178..388dbb8184106 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,29 +17,25 @@ package org.apache.spark.streaming.kafka -import java.lang.{Integer => JInt} -import java.lang.{Long => JLong} -import java.util.{Map => JMap} -import java.util.{Set => JSet} -import java.util.{List => JList} +import java.lang.{Integer => JInt, Long => JLong} +import java.util.{List => JList, Map => JMap, Set => JSet} -import scala.reflect.ClassTag import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} +import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaInputDStream, JavaPairInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.streaming.util.WriteAheadLogUtils +import org.apache.spark.{SparkContext, SparkException} object KafkaUtils { /** @@ -196,7 +192,6 @@ object KafkaUtils { * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -214,7 +209,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. @@ -230,7 +224,6 @@ object KafkaUtils { * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createRDD[ K: ClassTag, V: ClassTag, @@ -268,7 +261,6 @@ object KafkaUtils { * @param offsetRanges Each OffsetRange in the batch corresponds to a * range of offsets for a given Kafka topic/partition */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jsc: JavaSparkContext, keyClass: Class[K], @@ -287,7 +279,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. 
@@ -303,7 +294,6 @@ object KafkaUtils { * in which case leaders will be looked up on the driver. * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jsc: JavaSparkContext, keyClass: Class[K], @@ -327,7 +317,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -357,7 +346,6 @@ object KafkaUtils { * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -375,7 +363,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -405,7 +392,6 @@ object KafkaUtils { * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume */ - @Experimental def createDirectStream[ K: ClassTag, V: ClassTag, @@ -437,7 +423,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -472,7 +457,6 @@ object KafkaUtils { * starting point of the stream * @param messageHandler Function for translating each message and metadata into the desired type */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -499,7 +483,6 @@ object KafkaUtils { } /** - * :: Experimental :: * Create an input stream that directly pulls messages from Kafka Brokers * without using any receiver. This stream can guarantee that each message * from Kafka is included in transformations exactly once (see points below). @@ -533,7 +516,6 @@ object KafkaUtils { * to determine where the stream starts (defaults to "largest") * @param topics Names of the topics to consume */ - @Experimental def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]]( jssc: JavaStreamingContext, keyClass: Class[K], @@ -564,7 +546,7 @@ object KafkaUtils { * classOf[KafkaUtilsPythonHelper].newInstance(), and the createStream() * takes care of known parameters instead of passing them from Python */ -private class KafkaUtilsPythonHelper { +private[kafka] class KafkaUtilsPythonHelper { def createStream( jssc: JavaStreamingContext, kafkaParams: JMap[String, String], diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 2f8981d4898bd..8a5f371494511 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -19,10 +19,7 @@ package org.apache.spark.streaming.kafka import kafka.common.TopicAndPartition -import org.apache.spark.annotation.Experimental - /** - * :: Experimental :: * Represents any object that has a collection of [[OffsetRange]]s. 
This can be used access the * offset ranges in RDDs generated by the direct Kafka DStream (see * [[KafkaUtils.createDirectStream()]]). @@ -33,13 +30,11 @@ import org.apache.spark.annotation.Experimental * } * }}} */ -@Experimental trait HasOffsetRanges { def offsetRanges: Array[OffsetRange] } /** - * :: Experimental :: * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class * can be created with `OffsetRange.create()`. * @param topic Kafka topic name @@ -47,7 +42,6 @@ trait HasOffsetRanges { * @param fromOffset Inclusive starting offset * @param untilOffset Exclusive ending offset */ -@Experimental final class OffsetRange private( val topic: String, val partition: Int, @@ -84,10 +78,8 @@ final class OffsetRange private( } /** - * :: Experimental :: * Companion object the provides methods to create instances of [[OffsetRange]]. */ -@Experimental object OffsetRange { def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = new OffsetRange(topic, partition, fromOffset, untilOffset) diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 38a1114863d15..7b8d56d6faf2d 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -21,8 +21,8 @@ import scala.reflect.ClassTag import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object MQTTUtils { /** @@ -79,7 +79,7 @@ object MQTTUtils { * This is a helper class that wraps the methods in MQTTUtils into more Python-friendly class and * function so that it can be easily instantiated and called from Python's MQTTUtils. 
*/ -private class MQTTUtilsPythonHelper { +private[mqtt] class MQTTUtilsPythonHelper { def createStream( jssc: JavaStreamingContext, diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala index 1a371b7008824..1618e2c088b70 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -33,7 +33,7 @@ import org.apache.spark.{Logging, SparkConf} /** * Share codes for Scala and Python unit tests */ -private class MQTTTestUtils extends Logging { +private[mqtt] class MQTTTestUtils extends Logging { private val persistenceDir = Utils.createTempDir() private val brokerHost = "localhost" diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 711aade182945..c8eec13ec7dc7 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -36,7 +36,7 @@ import org.apache.spark.Logging /** * Shared utility methods for performing Kinesis tests that actually transfer data */ -private class KinesisTestUtils extends Logging { +private[kinesis] class KinesisTestUtils extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java index 8c0fdfa9c7478..3738fc1a235c2 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -21,6 +21,8 @@ import java.util.Iterator; /** + * :: DeveloperApi :: + * * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming * to save the received data (by receivers) and associated metadata to a reliable storage, so that * they can be recovered after driver failures. See the Spark documentation for more information diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java index 02324189b7822..662889e779fb2 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -18,6 +18,8 @@ package org.apache.spark.streaming.util; /** + * :: DeveloperApi :: + * * This abstract class represents a handle that refers to a record written in a * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. 
* It must contain all the information necessary for the record to be read and returned by diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c8dd6e06812dc..5f6c5b024085c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -222,7 +222,7 @@ private[streaming] object WriteAheadLogBasedBlockHandler { /** * A utility that will wrap the Iterator to get the count */ -private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { +private[streaming] class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { private var _count = 0 private def isFullyConsumed: Boolean = !iterator.hasNext diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index aae3acf7aba3e..30d25a64e307a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -546,7 +546,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * Function to start the receiver on the worker node. Use a class instead of closure to avoid * the serialization issue. */ -private class StartReceiverFunc( +private[streaming] class StartReceiverFunc( checkpointDirOption: Option[String], serializableHadoopConf: SerializableConfiguration) extends (Iterator[Receiver[_]] => Unit) with Serializable { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 0c891662c264f..90d1b0fadecfc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -28,7 +28,7 @@ import org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} import org.apache.spark.streaming.ui.StreamingJobProgressListener.{SparkJobId, OutputOpId} import org.apache.spark.ui.jobs.UIData.JobUIData -private case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) +private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { private val streamingListener = parent.listener From 9ce0c7ad333f4a3c01207e5e9ed42bcafb99d894 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 24 Aug 2015 13:48:01 -0700 Subject: [PATCH 047/802] [SPARK-7710] [SPARK-7998] [DOCS] Docs for DataFrameStatFunctions This PR contains examples on how to use some of the Stat Functions available for DataFrames under `df.stat`. rxin Author: Burak Yavuz Closes #8378 from brkyvz/update-sql-docs. 
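A compact, runnable version of the stat helpers being documented below; it assumes an existing `SQLContext` named `sqlContext`, and the seeds simply mirror the scaladoc snippets rather than anything required by the API.

```scala
import org.apache.spark.sql.functions.rand

val df = sqlContext.range(0, 10)
  .withColumn("rand1", rand(10L))
  .withColumn("rand2", rand(27L))

println(df.stat.corr("rand1", "rand2"))   // Pearson correlation coefficient (Double)
println(df.stat.cov("rand1", "rand2"))    // sample covariance (Double)

// Stratified sampling by key, keeping all of stratum 1 and none of the others.
val kv = sqlContext.createDataFrame(Seq((1, "a"), (1, "b"), (2, "c"), (3, "d"))).toDF("key", "value")
kv.stat.sampleBy("key", Map(1 -> 1.0), 36L).show()
```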
--- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../spark/sql/DataFrameStatFunctions.scala | 101 ++++++++++++++++++ 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d6688b24ae7d6..791c10c3d7ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -684,7 +684,7 @@ class DataFrame private[sql]( // make it a NamedExpression. case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analzyer will generate the + // Leave an unaliased explode with an empty list of names since the analyzer will generate the // correct defaults after the nested expression's type has been resolved. case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 2e68e358f2f1f..69c984717526d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -39,6 +39,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the second column * @return the covariance of the two columns. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.cov("rand1", "rand2") + * res1: Double = 0.065... + * }}} + * * @since 1.4.0 */ def cov(col1: String, col2: String): Double = { @@ -54,6 +61,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the column to calculate the correlation against * @return The Pearson Correlation Coefficient as a Double. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2") + * res1: Double = 0.613... + * }}} + * * @since 1.4.0 */ def corr(col1: String, col2: String, method: String): Double = { @@ -69,6 +83,13 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param col2 the name of the column to calculate the correlation against * @return The Pearson Correlation Coefficient as a Double. * + * {{{ + * val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10)) + * .withColumn("rand2", rand(seed=27)) + * df.stat.corr("rand1", "rand2", "pearson") + * res1: Double = 0.613... + * }}} + * * @since 1.4.0 */ def corr(col1: String, col2: String): Double = { @@ -92,6 +113,20 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * of the DataFrame. * @return A DataFrame containing for the contingency table. 
* + * {{{ + * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * (3, 3))).toDF("key", "value") + * val ct = df.stat.crosstab("key", "value") + * ct.show() + * +---------+---+---+---+ + * |key_value| 1| 2| 3| + * +---------+---+---+---+ + * | 2| 2| 0| 1| + * | 1| 1| 1| 0| + * | 3| 0| 1| 1| + * +---------+---+---+---+ + * }}} + * * @since 1.4.0 */ def crosstab(col1: String, col2: String): DataFrame = { @@ -112,6 +147,32 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * than 1e-4. * @return A Local DataFrame with the Array of frequent items for each column. * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Array("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... | + * +----------+ + * }}} + * * @since 1.4.0 */ def freqItems(cols: Array[String], support: Double): DataFrame = { @@ -147,6 +208,32 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * + * {{{ + * val rows = Seq.tabulate(100) { i => + * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) + * } + * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns + * // "a" and "b" + * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) + * freqSingles.show() + * +-----------+-------------+ + * |a_freqItems| b_freqItems| + * +-----------+-------------+ + * | [1, 99]|[-1.0, -99.0]| + * +-----------+-------------+ + * // find the pair of items with a frequency greater than 0.1 in columns "a" and "b" + * val pairDf = df.select(struct("a", "b").as("a-b")) + * val freqPairs = pairDf.stat.freqItems(Seq("a-b"), 0.1) + * freqPairs.select(explode($"a-b_freqItems").as("freq_ab")).show() + * +----------+ + * | freq_ab| + * +----------+ + * | [1,-1.0]| + * | ... 
| + * +----------+ + * }}} + * * @since 1.4.0 */ def freqItems(cols: Seq[String], support: Double): DataFrame = { @@ -180,6 +267,20 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @tparam T stratum type * @return a new [[DataFrame]] that represents the stratified sample * + * {{{ + * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * (3, 3))).toDF("key", "value") + * val fractions = Map(1 -> 1.0, 3 -> 0.5) + * df.stat.sampleBy("key", fractions, 36L).show() + * +---+-----+ + * |key|value| + * +---+-----+ + * | 1| 1| + * | 1| 2| + * | 3| 2| + * +---+-----+ + * }}} + * * @since 1.5.0 */ def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { From 662bb9667669cb07cf6d2ccee0d8e76bb561cd89 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 24 Aug 2015 14:10:50 -0700 Subject: [PATCH 048/802] [SPARK-10144] [UI] Actually show peak execution memory by default The peak execution memory metric was introduced in SPARK-8735. That was before Tungsten was enabled by default, so it assumed that `spark.sql.unsafe.enabled` must be explicitly set to true. The result is that the memory is not displayed by default. Author: Andrew Or Closes #8345 from andrewor14/show-memory-default. --- .../main/scala/org/apache/spark/ui/jobs/StagePage.scala | 6 ++---- .../test/scala/org/apache/spark/ui/StagePageSuite.scala | 8 ++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index fb4556b836859..4adc6596ba21c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -68,8 +68,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. 
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private val displayPeakExecutionMemory = - parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) + private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true) def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { @@ -1193,8 +1192,7 @@ private[ui] class TaskPagedTable( desc: Boolean) extends PagedTable[TaskTableRowData] { // We only track peak memory used for unsafe operators - private val displayPeakExecutionMemory = - conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean) + private val displayPeakExecutionMemory = conf.getBoolean("spark.sql.unsafe.enabled", true) override def tableId: String = "task-table" diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 98f9314f31dff..3388c6dca81f1 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -33,14 +33,18 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { test("peak execution memory only displayed if unsafe is enabled") { val unsafeConf = "spark.sql.unsafe.enabled" - val conf = new SparkConf().set(unsafeConf, "true") + val conf = new SparkConf(false).set(unsafeConf, "true") val html = renderStagePage(conf).toString().toLowerCase val targetString = "peak execution memory" assert(html.contains(targetString)) // Disable unsafe and make sure it's not there - val conf2 = new SparkConf().set(unsafeConf, "false") + val conf2 = new SparkConf(false).set(unsafeConf, "false") val html2 = renderStagePage(conf2).toString().toLowerCase assert(!html2.contains(targetString)) + // Avoid setting anything; it should be displayed by default + val conf3 = new SparkConf(false) + val html3 = renderStagePage(conf3).toString().toLowerCase + assert(html3.contains(targetString)) } /** From a2f4cdceba32aaa0df59df335ca0ce1ac73fc6c2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 24 Aug 2015 14:11:19 -0700 Subject: [PATCH 049/802] [SPARK-8580] [SQL] Refactors ParquetHiveCompatibilitySuite and adds more test cases This PR refactors `ParquetHiveCompatibilitySuite` so that it's easier to add new test cases. Hit two bugs, SPARK-10177 and HIVE-11625, while working on this, added test cases for them and marked as ignored for now. SPARK-10177 will be addressed in a separate PR. Author: Cheng Lian Closes #8392 from liancheng/spark-8580/parquet-hive-compat-tests. 
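After the refactoring, adding a new compatibility case is essentially one call that pairs a `Row` of Scala values with the matching Hive column types. A rough sketch (illustrative only; the types below are already covered by the existing "simple primitives" test):

```scala
test("more primitives (sketch)") {
  // One row of Scala values plus the corresponding Hive DDL types; a matching
  // all-null row is appended automatically by the helper.
  testParquetHiveCompatibility(
    Row(42, "spark"),
    "INT", "STRING")
}
```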
--- .../hive/ParquetHiveCompatibilitySuite.scala | 132 ++++++++++++------ 1 file changed, 93 insertions(+), 39 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 13452e71a1b3b..bc30180cf0917 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.hive +import java.sql.Timestamp +import java.util.{Locale, TimeZone} + import org.apache.hadoop.hive.conf.HiveConf +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.{Row, SQLConf, SQLContext} -class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { - import ParquetCompatibilityTest.makeNullable - +class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with BeforeAndAfterAll { override def _sqlContext: SQLContext = TestHive private val sqlContext = _sqlContext @@ -35,69 +37,121 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest { */ private val stagingDir = new HiveConf().getVar(HiveConf.ConfVars.STAGINGDIR) - test("Read Parquet file generated by parquet-hive") { + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + protected override def beforeAll(): Unit = { + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + Locale.setDefault(Locale.US) + } + + override protected def afterAll(): Unit = { + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + } + + override protected def logParquetSchema(path: String): Unit = { + val schema = readParquetSchema(path, { path => + !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) + }) + + logInfo( + s"""Schema of the Parquet file written by parquet-avro: + |$schema + """.stripMargin) + } + + private def testParquetHiveCompatibility(row: Row, hiveTypes: String*): Unit = { withTable("parquet_compat") { withTempPath { dir => val path = dir.getCanonicalPath + // Hive columns are always nullable, so here we append a all-null row. + val rows = row :: Row(Seq.fill(row.length)(null): _*) :: Nil + + // Don't convert Hive metastore Parquet tables to let Hive write those Parquet files. 
withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { withTempTable("data") { - sqlContext.sql( + val fields = hiveTypes.zipWithIndex.map { case (typ, index) => s" col_$index $typ" } + + val ddl = s"""CREATE TABLE parquet_compat( - | bool_column BOOLEAN, - | byte_column TINYINT, - | short_column SMALLINT, - | int_column INT, - | long_column BIGINT, - | float_column FLOAT, - | double_column DOUBLE, - | - | strings_column ARRAY, - | int_to_string_column MAP + |${fields.mkString(",\n")} |) |STORED AS PARQUET |LOCATION '$path' + """.stripMargin + + logInfo( + s"""Creating testing Parquet table with the following DDL: + |$ddl """.stripMargin) + sqlContext.sql(ddl) + val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(makeRows).coalesce(1) + val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } - val schema = readParquetSchema(path, { path => - !path.getName.startsWith("_") && !path.getName.startsWith(stagingDir) - }) - - logInfo( - s"""Schema of the Parquet file written by parquet-hive: - |$schema - """.stripMargin) + logParquetSchema(path) // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(path), makeRows) + checkAnswer(sqlContext.read.parquet(path), rows) } } } } - def makeRows: Seq[Row] = { - (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + test("simple primitives") { + testParquetHiveCompatibility( + Row(true, 1.toByte, 2.toShort, 3, 4.toLong, 5.1f, 6.1d, "foo"), + "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") + } + ignore("SPARK-10177 timestamp") { + testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") + } + + test("array") { + testParquetHiveCompatibility( Row( - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i.toByte: java.lang.Byte), - nullable((i + 1).toShort: java.lang.Short), - nullable(i + 2: Integer), - nullable(i.toLong * 10: java.lang.Long), - nullable(i.toFloat + 0.1f: java.lang.Float), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(Seq.tabulate(3)(n => s"arr_${i + n}")), - nullable(Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap)) - } + Seq[Integer](1: Integer, null, 2: Integer, null), + Seq[String]("foo", null, "bar", null), + Seq[Seq[Integer]]( + Seq[Integer](1: Integer, null), + Seq[Integer](2: Integer, null))), + "ARRAY", + "ARRAY", + "ARRAY>") + } + + test("map") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + (1: Integer) -> "foo", + (2: Integer) -> null)), + "MAP") + } + + // HIVE-11625: Parquet map entries with null keys are dropped by Hive + ignore("map entries with null keys") { + testParquetHiveCompatibility( + Row( + Map[Integer, String]( + null.asInstanceOf[Integer] -> "bar", + null.asInstanceOf[Integer] -> null)), + "MAP") + } + + test("struct") { + testParquetHiveCompatibility( + Row(Row(1, Seq("foo", "bar", null))), + "STRUCT>") } } From cb2d2e15844d7ae34b5dd7028b55e11586ed93fa Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 24 Aug 2015 22:35:21 +0100 Subject: [PATCH 050/802] [SPARK-9758] [TEST] [SQL] Compilation issue for hive test / wrong package? 
Move `test.org.apache.spark.sql.hive` package tests to apparent intended `org.apache.spark.sql.hive` as they don't intend to test behavior from outside org.apache.spark.* Alternate take, per discussion at https://github.com/apache/spark/pull/8051 I think this is what vanzin and I had in mind but also CC rxin to cross-check, as this does indeed depend on whether these tests were accidentally in this package or not. Testing from a `test.org.apache.spark` package is legitimate but didn't seem to be the intent here. Author: Sean Owen Closes #8307 from srowen/SPARK-9758. --- .../org/apache/spark/sql/hive/JavaDataFrameSuite.java | 6 ++---- .../spark/sql/hive/JavaMetastoreDataSourcesSuite.java | 3 +-- .../org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java | 2 +- .../org/apache/spark/sql/hive/aggregate/MyDoubleSum.java | 2 +- .../apache/spark/sql/hive/execution/UDFIntegerToString.java | 0 .../org/apache/spark/sql/hive/execution/UDFListListInt.java | 0 .../org/apache/spark/sql/hive/execution/UDFListString.java | 0 .../apache/spark/sql/hive/execution/UDFStringString.java | 0 .../org/apache/spark/sql/hive/execution/UDFTwoListList.java | 0 .../spark/sql/hive/execution/AggregationQuerySuite.scala | 2 +- 10 files changed, 6 insertions(+), 9 deletions(-) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/JavaDataFrameSuite.java (94%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java (98%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java (98%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java (98%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/execution/UDFIntegerToString.java (100%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/execution/UDFListListInt.java (100%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/execution/UDFListString.java (100%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/execution/UDFStringString.java (100%) rename sql/hive/src/test/java/{test => }/org/apache/spark/sql/hive/execution/UDFTwoListList.java (100%) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java similarity index 94% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index a30dfa554eabc..019d8a30266e2 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.IOException; import java.util.ArrayList; @@ -31,10 +31,8 @@ import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; -import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import test.org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import org.apache.spark.sql.hive.aggregate.MyDoubleSum; public class JavaDataFrameSuite { private transient JavaSparkContext sc; diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 98% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 15c2c3deb0d83..4192155975c47 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.sql.hive; +package org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -37,7 +37,6 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; -import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java similarity index 98% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 2961b803f14aa..5a167edd89592 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.sql.hive.aggregate; +package org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java similarity index 98% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index c71882a6e7bed..c3b7768e71bf8 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package test.org.apache.spark.sql.hive.aggregate; +package org.apache.spark.sql.hive.aggregate; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 119663af1887a..4886a85948367 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { override def _sqlContext: SQLContext = TestHive From 13db11cb08eb90eb0ea3402c9fe0270aa282f247 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 24 Aug 2015 15:38:54 -0700 Subject: [PATCH 051/802] [SPARK-10061] [DOC] ML ensemble docs User guide for spark.ml GBTs and Random Forests. The examples are copied from the decision tree guide and modified to run. I caught some issues I had somehow missed in the tree guide as well. I have run all examples, including Java ones. (Of course, I thought I had previously as well...) 
CC: mengxr manishamde yanboliang Author: Joseph K. Bradley Closes #8369 from jkbradley/ml-ensemble-docs. --- docs/ml-decision-tree.md | 75 ++- docs/ml-ensembles.md | 952 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 976 insertions(+), 51 deletions(-) diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index 958c6f5e4716c..542819e93e6dc 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -30,7 +30,7 @@ The Pipelines API for Decision Trees offers a bit more functionality than the or Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described in the [Ensembles guide](ml-ensembles.html). -# Inputs and Outputs (Predictions) +# Inputs and Outputs We list the input and output (prediction) column types here. All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. @@ -234,7 +234,7 @@ IndexToString labelConverter = new IndexToString() // Chain indexers and tree in a Pipeline Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); // Train model. This also runs the indexers. PipelineModel model = pipeline.fit(trainingData); @@ -315,10 +315,13 @@ print treeModel # summary only ## Regression +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. +
-More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). +More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). {% highlight scala %} import org.apache.spark.ml.Pipeline @@ -347,7 +350,7 @@ val dt = new DecisionTreeRegressor() .setLabelCol("label") .setFeaturesCol("indexedFeatures") -// Chain indexers and tree in a Pipeline +// Chain indexer and tree in a Pipeline val pipeline = new Pipeline() .setStages(Array(featureIndexer, dt)) @@ -365,9 +368,7 @@ val evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") .setMetricName("rmse") -// We negate the RMSE value since RegressionEvalutor returns negated RMSE -// (since evaluation metrics are meant to be maximized by CrossValidator). -val rmse = - evaluator.evaluate(predictions) +val rmse = evaluator.evaluate(predictions) println("Root Mean Squared Error (RMSE) on test data = " + rmse) val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] @@ -377,14 +378,15 @@ println("Learned regression tree model:\n" + treeModel.toDebugString)
-More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). +More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). {% highlight java %} import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.*; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; import org.apache.spark.ml.regression.DecisionTreeRegressionModel; import org.apache.spark.ml.regression.DecisionTreeRegressor; import org.apache.spark.mllib.regression.LabeledPoint; @@ -396,17 +398,12 @@ import org.apache.spark.sql.DataFrame; RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); // Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. VectorIndexerModel featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .setMaxCategories(4) .fit(data); // Split the data into training and test sets (30% held out for testing) @@ -416,61 +413,49 @@ DataFrame testData = splits[1]; // Train a DecisionTree model. DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures"); -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and tree in a Pipeline +// Chain indexer and tree in a Pipeline Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + .setStages(new PipelineStage[] {featureIndexer, dt}); -// Train model. This also runs the indexers. +// Train model. This also runs the indexer. PipelineModel model = pipeline.fit(trainingData); // Make predictions. DataFrame predictions = model.transform(testData); // Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); +predictions.select("label", "features").show(5); // Select (prediction, true label) and compute test error RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("indexedLabel") + .setLabelCol("label") .setPredictionCol("prediction") .setMetricName("rmse"); -// We negate the RMSE value since RegressionEvalutor returns negated RMSE -// (since evaluation metrics are meant to be maximized by CrossValidator). 
-double rmse = - evaluator.evaluate(predictions); +double rmse = evaluator.evaluate(predictions); System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); DecisionTreeRegressionModel treeModel = - (DecisionTreeRegressionModel)(model.stages()[2]); + (DecisionTreeRegressionModel)(model.stages()[1]); System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); {% endhighlight %}
-More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). {% highlight python %} from pyspark.ml import Pipeline from pyspark.ml.regression import DecisionTreeRegressor -from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator from pyspark.mllib.util import MLUtils # Load and parse the data file, converting it to a DataFrame. data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) # Automatically identify categorical features, and index them. # We specify maxCategories so features with > 4 distinct values are treated as continuous. featureIndexer =\ @@ -480,26 +465,24 @@ featureIndexer =\ (trainingData, testData) = data.randomSplit([0.7, 0.3]) # Train a DecisionTree model. -dt = DecisionTreeRegressor(labelCol="indexedLabel", featuresCol="indexedFeatures") +dt = DecisionTreeRegressor(featuresCol="indexedFeatures") -# Chain indexers and tree in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) +# Chain indexer and tree in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, dt]) -# Train model. This also runs the indexers. +# Train model. This also runs the indexer. model = pipeline.fit(trainingData) # Make predictions. predictions = model.transform(testData) # Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) +predictions.select("prediction", "label", "features").show(5) # Select (prediction, true label) and compute test error evaluator = RegressionEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="rmse") -# We negate the RMSE value since RegressionEvalutor returns negated RMSE -# (since evaluation metrics are meant to be maximized by CrossValidator). -rmse = -evaluator.evaluate(predictions) + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) print "Root Mean Squared Error (RMSE) on test data = %g" % rmse treeModel = model.stages[1] diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 9ff50e95fc479..62749909e01dc 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -11,11 +11,947 @@ displayTitle: ML - Ensembles An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -The Pipelines API supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) -## OneVsRest +## Tree Ensembles -[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. +The Pipelines API supports two major tree ensemble algorithms: [Random Forests](http://en.wikipedia.org/wiki/Random_forest) and [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting). +Both use [MLlib decision trees](ml-decision-tree.html) as their base models. 
+ +Users can find more information about ensemble algorithms in the [MLlib Ensemble guide](mllib-ensembles.html). In this section, we demonstrate the Pipelines API for ensembles. + +The main differences between this API and the [original MLlib ensembles API](mllib-ensembles.html) are: +* support for ML Pipelines +* separation of classification vs. regression +* use of DataFrame metadata to distinguish continuous and categorical features +* a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification. + +### Random Forests + +[Random forests](http://en.wikipedia.org/wiki/Random_forest) +are ensembles of [decision trees](ml-decision-tree.html). +Random forests combine many decision trees in order to reduce the risk of overfitting. +MLlib supports random forests for binary and multiclass classification and for regression, +using both continuous and categorical features. + +This section gives examples of using random forests with the Pipelines API. +For more information on the algorithm, please see the [main MLlib docs on random forests](mllib-ensembles.html). + +#### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +##### Input Columns + +
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>labelCol</td>
+      <td>Double</td>
+      <td>"label"</td>
+      <td>Label to predict</td>
+    </tr>
+    <tr>
+      <td>featuresCol</td>
+      <td>Vector</td>
+      <td>"features"</td>
+      <td>Feature vector</td>
+    </tr>
+  </tbody>
+</table>
+
+##### Output Columns (Predictions)
+
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+      <th>Notes</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>predictionCol</td>
+      <td>Double</td>
+      <td>"prediction"</td>
+      <td>Predicted label</td>
+      <td></td>
+    </tr>
+    <tr>
+      <td>rawPredictionCol</td>
+      <td>Vector</td>
+      <td>"rawPrediction"</td>
+      <td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td>
+      <td>Classification only</td>
+    </tr>
+    <tr>
+      <td>probabilityCol</td>
+      <td>Vector</td>
+      <td>"probability"</td>
+      <td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td>
+      <td>Classification only</td>
+    </tr>
+  </tbody>
+</table>
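+
+For example, with a fitted `RandomForestClassificationModel` (or a `PipelineModel` that contains one), these columns appear in the transformed `DataFrame` roughly as follows (illustrative sketch; the column names assume the defaults above, and `model`/`testData` come from the classification example below):
+
+{% highlight scala %}
+// Transform the held-out data and inspect the prediction-related columns.
+val predictions = model.transform(testData)
+predictions.select("prediction", "rawPrediction", "probability").show(5)
+{% endhighlight %}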
+ +#### Example: Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
+
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.RandomForestClassifier +import org.apache.spark.ml.classification.RandomForestClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a RandomForest model. +val rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setNumTrees(10) + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and forest in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] +println("Learned classification forest model:\n" + rfModel.toDebugString) +{% endhighlight %} +
+ +
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/RandomForestClassifier.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.RandomForestClassifier; +import org.apache.spark.ml.classification.RandomForestClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a RandomForest model. +RandomForestClassifier rf = new RandomForestClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and forest in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); +double accuracy = evaluator.evaluate(predictions); +System.out.println("Test Error = " + (1.0 - accuracy)); + +RandomForestClassificationModel rfModel = + (RandomForestClassificationModel)(model.stages()[2]); +System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); +{% endhighlight %} +
+ +
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a RandomForest model. +rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + +# Chain indexers and forest in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +rfModel = model.stages[2] +print rfModel # summary only +{% endhighlight %} +
+
+ +#### Example: Regression + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
+
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.RandomForestRegressor +import org.apache.spark.ml.regression.RandomForestRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a RandomForest model. +val rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + +// Chain indexer and forest in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, rf)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +val rmse = evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] +println("Learned regression forest model:\n" + rfModel.toDebugString) +{% endhighlight %} +
+ +
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/RandomForestRegressor.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.RandomForestRegressionModel; +import org.apache.spark.ml.regression.RandomForestRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a RandomForest model. +RandomForestRegressor rf = new RandomForestRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures"); + +// Chain indexer and forest in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, rf}); + +// Train model. This also runs the indexer. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); +double rmse = evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +RandomForestRegressionModel rfModel = + (RandomForestRegressionModel)(model.stages()[1]); +System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); +{% endhighlight %} +
+ +
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import RandomForestRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a RandomForest model. +rf = RandomForestRegressor(featuresCol="indexedFeatures") + +# Chain indexer and forest in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, rf]) + +# Train model. This also runs the indexer. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +rfModel = model.stages[1] +print rfModel # summary only +{% endhighlight %} +
+
+ +### Gradient-Boosted Trees (GBTs) + +[Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) +are ensembles of [decision trees](ml-decision-tree.html). +GBTs iteratively train decision trees in order to minimize a loss function. +MLlib supports GBTs for binary classification and for regression, +using both continuous and categorical features. + +This section gives examples of using GBTs with the Pipelines API. +For more information on the algorithm, please see the [main MLlib docs on GBTs](mllib-ensembles.html). + +#### Inputs and Outputs + +We list the input and output (prediction) column types here. +All output columns are optional; to exclude an output column, set its corresponding Param to an empty string. + +##### Input Columns + + + + + + + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>labelCol</td>
+      <td>Double</td>
+      <td>"label"</td>
+      <td>Label to predict</td>
+    </tr>
+    <tr>
+      <td>featuresCol</td>
+      <td>Vector</td>
+      <td>"features"</td>
+      <td>Feature vector</td>
+    </tr>
+  </tbody>
+</table>
+
+Note that `GBTClassifier` currently only supports binary labels.
+
+##### Output Columns (Predictions)
+
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+      <th>Notes</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>predictionCol</td>
+      <td>Double</td>
+      <td>"prediction"</td>
+      <td>Predicted label</td>
+      <td></td>
+    </tr>
+  </tbody>
+</table>
+ +In the future, `GBTClassifier` will also output columns for `rawPrediction` and `probability`, just as `RandomForestClassifier` does. + +#### Example: Classification + +The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. +We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. + +
+
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.GBTClassifier +import org.apache.spark.ml.classification.GBTClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a GBT model. +val gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + +// Convert indexed labels back to original labels. +val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + +// Chain indexers and GBT in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) + +// Train model. This also runs the indexers. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") +val accuracy = evaluator.evaluate(predictions) +println("Test Error = " + (1.0 - accuracy)) + +val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] +println("Learned classification GBT model:\n" + gbtModel.toDebugString) +{% endhighlight %} +
+ +
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/GBTClassifier.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.GBTClassifier; +import org.apache.spark.ml.classification.GBTClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Index labels, adding metadata to the label column. +// Fit on whole dataset to include all labels in index. +StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a GBT model. +GBTClassifier gbt = new GBTClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + +// Convert indexed labels back to original labels. +IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + +// Chain indexers and GBT in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); + +// Train model. This also runs the indexers. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("predictedLabel", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); +double accuracy = evaluator.evaluate(predictions); +System.out.println("Test Error = " + (1.0 - accuracy)); + +GBTClassificationModel gbtModel = + (GBTClassificationModel)(model.stages()[2]); +System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); +{% endhighlight %} +
+ +
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Index labels, adding metadata to the label column. +# Fit on whole dataset to include all labels in index. +labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GBT model. +gbt = GBTClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", maxIter=10) + +# Chain indexers and GBT in a Pipeline +pipeline = Pipeline(stages=[labelIndexer, featureIndexer, gbt]) + +# Train model. This also runs the indexers. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "indexedLabel", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") +accuracy = evaluator.evaluate(predictions) +print "Test Error = %g" % (1.0 - accuracy) + +gbtModel = model.stages[2] +print gbtModel # summary only +{% endhighlight %} +
+
+ +#### Example: Regression + +Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not +be true in general. + +
+
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor) for more details. + +{% highlight scala %} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.GBTRegressor +import org.apache.spark.ml.regression.GBTRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils + +// Load and parse the data file, converting it to a DataFrame. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + +// Split the data into training and test sets (30% held out for testing) +val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + +// Train a GBT model. +val gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10) + +// Chain indexer and GBT in a Pipeline +val pipeline = new Pipeline() + .setStages(Array(featureIndexer, gbt)) + +// Train model. This also runs the indexer. +val model = pipeline.fit(trainingData) + +// Make predictions. +val predictions = model.transform(testData) + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +// Select (prediction, true label) and compute test error +val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") +val rmse = evaluator.evaluate(predictions) +println("Root Mean Squared Error (RMSE) on test data = " + rmse) + +val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] +println("Learned regression GBT model:\n" + gbtModel.toDebugString) +{% endhighlight %} +
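+
+As the note above says, the best number of iterations is data dependent. A hedged sketch of one way to pick
+`maxIter`, reusing the `pipeline`, `gbt`, and `trainingData` defined above (the grid values are purely
+illustrative):
+
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
+
+// Candidate iteration counts; good values depend on the dataset.
+val paramGrid = new ParamGridBuilder()
+  .addGrid(gbt.maxIter, Array(5, 10, 20))
+  .build()
+
+val cv = new CrossValidator()
+  .setEstimator(pipeline)
+  .setEvaluator(new RegressionEvaluator().setMetricName("rmse"))
+  .setEstimatorParamMaps(paramGrid)
+  .setNumFolds(3)
+
+// Fits one pipeline per grid entry and keeps the setting with the best cross-validated RMSE.
+val cvModel = cv.fit(trainingData)
+{% endhighlight %}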
+ +
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GBTRegressor.html) for more details. + +{% highlight java %} +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.GBTRegressionModel; +import org.apache.spark.ml.regression.GBTRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +// Load and parse the data file, converting it to a DataFrame. +RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); +DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); + +// Automatically identify categorical features, and index them. +// Set maxCategories so features with > 4 distinct values are treated as continuous. +VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + +// Split the data into training and test sets (30% held out for testing) +DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); +DataFrame trainingData = splits[0]; +DataFrame testData = splits[1]; + +// Train a GBT model. +GBTRegressor gbt = new GBTRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + .setMaxIter(10); + +// Chain indexer and GBT in a Pipeline +Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {featureIndexer, gbt}); + +// Train model. This also runs the indexer. +PipelineModel model = pipeline.fit(trainingData); + +// Make predictions. +DataFrame predictions = model.transform(testData); + +// Select example rows to display. +predictions.select("prediction", "label", "features").show(5); + +// Select (prediction, true label) and compute test error +RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); +double rmse = evaluator.evaluate(predictions); +System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + +GBTRegressionModel gbtModel = + (GBTRegressionModel)(model.stages()[1]); +System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); +{% endhighlight %} +
+ +
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor) for more details. + +{% highlight python %} +from pyspark.ml import Pipeline +from pyspark.ml.regression import GBTRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils + +# Load and parse the data file, converting it to a DataFrame. +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +# Automatically identify categorical features, and index them. +# Set maxCategories so features with > 4 distinct values are treated as continuous. +featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + +# Split the data into training and test sets (30% held out for testing) +(trainingData, testData) = data.randomSplit([0.7, 0.3]) + +# Train a GBT model. +gbt = GBTRegressor(featuresCol="indexedFeatures", maxIter=10) + +# Chain indexer and GBT in a Pipeline +pipeline = Pipeline(stages=[featureIndexer, gbt]) + +# Train model. This also runs the indexer. +model = pipeline.fit(trainingData) + +# Make predictions. +predictions = model.transform(testData) + +# Select example rows to display. +predictions.select("prediction", "label", "features").show(5) + +# Select (prediction, true label) and compute test error +evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") +rmse = evaluator.evaluate(predictions) +print "Root Mean Squared Error (RMSE) on test data = %g" % rmse + +gbtModel = model.stages[1] +print gbtModel # summary only +{% endhighlight %} +
+
+ + +## One-vs-Rest (a.k.a. One-vs-All) + +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is a machine learning reduction for performing multiclass classification with a base classifier that can perform binary classification efficiently; it is also known as "One-vs-All." `OneVsRest` is implemented as an `Estimator`. It takes a base classifier (an instance of `Classifier`) and creates one binary classification problem for each of the k classes: the classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. @@ -28,6 +964,9 @@ The example below demonstrates how to load the
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) for more details. + {% highlight scala %} import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -64,9 +1003,12 @@ println("label\tfpr\n") } {% endhighlight %}
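+
+Because `OneVsRest` reduces a k-class problem to k binary problems, the fitted `OneVsRestModel` keeps one
+binary classification model per class. A small sketch of inspecting them; the name `ovrModel` is assumed to
+refer to the model fitted in the example above:
+
+{% highlight scala %}
+// `ovrModel` is assumed to be the OneVsRestModel fitted above.
+ovrModel.models.zipWithIndex.foreach { case (binaryModel, classIndex) =>
+  println(s"class $classIndex: ${binaryModel.getClass.getSimpleName}")
+}
+{% endhighlight %}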
+
-{% highlight java %} +Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRest.html) for more details. + +{% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegression; @@ -88,7 +1030,7 @@ RDD data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_multiclass_classification_data.txt"); DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); -DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345); +DataFrame[] splits = dataFrame.randomSplit(new double[] {0.7, 0.3}, 12345); DataFrame train = splits[0]; DataFrame test = splits[1]; From d7b4c095271c36fcc7f9ded267ecf5ec66fac803 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Aug 2015 16:17:45 -0700 Subject: [PATCH 052/802] [SPARK-10190] Fix NPE in CatalystTypeConverters Decimal toScala converter This adds a missing null check to the Decimal `toScala` converter in `CatalystTypeConverters`, fixing an NPE. Author: Josh Rosen Closes #8401 from JoshRosen/SPARK-10190. --- .../apache/spark/sql/catalyst/CatalystTypeConverters.scala | 5 ++++- .../spark/sql/catalyst/CatalystTypeConvertersSuite.scala | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 8d0c64eae4774..966623ed017ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -329,7 +329,10 @@ object CatalystTypeConverters { null } } - override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScala(catalystValue: Decimal): JavaBigDecimal = { + if (catalystValue == null) null + else catalystValue.toJavaBigDecimal + } override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index df0f04563edcf..03bb102c67fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -32,7 +32,9 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + DecimalType.SYSTEM_DEFAULT, + DecimalType.USER_DEFAULT) test("null handling in rows") { val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) From 2bf338c626e9d97ccc033cfadae8b36a82c66fd1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 24 Aug 2015 18:10:51 -0700 Subject: [PATCH 053/802] [SPARK-10165] [SQL] Await child resolution in ResolveFunctions Currently, we eagerly attempt to resolve functions, even before their children are resolved. However, this is not valid in cases where we need to know the types of the input arguments (i.e. when resolving Hive UDFs). As a fix, this PR delays function resolution until the functions children are resolved. 
This change also necessitates a change to the way we resolve aggregate expressions that are not in aggregate operators (e.g., in `HAVING` or `ORDER BY` clauses). Specifically, we can't assume that these misplaced functions will be resolved, allowing us to differentiate aggregate functions from normal functions. To compensate for this change we now attempt to resolve these unresolved expressions in the context of the aggregate operator, before checking to see if any aggregate expressions are present. Author: Michael Armbrust Closes #8371 from marmbrus/hiveUDFResolution. --- .../sql/catalyst/analysis/Analyzer.scala | 116 +++++++++++------- .../sql/hive/execution/HiveUDFSuite.scala | 5 + 2 files changed, 77 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d0eb9c2c90bdf..1a5de15c61f86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -78,7 +78,7 @@ class Analyzer( ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: - UnresolvedHavingClauseAttributes :: + ResolveAggregateFunctions :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -452,37 +452,6 @@ class Analyzer( logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } - case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) - if !s.resolved && a.resolved => - // A small hack to create an object that will allow us to resolve any references that - // refer to named expressions that are present in the grouping expressions. - val groupingRelation = LocalRelation( - grouping.collect { case ne: NamedExpression => ne.toAttribute } - ) - - // Find sort attributes that are projected away so we can temporarily add them back in. - val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation) - - // Find aggregate expressions and evaluate them early, since they can't be evaluated in a - // Sort. - val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { - case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => - val aliased = Alias(aggOrdering.child, "_aggOrdering")() - (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) - - case other => (other, None) - }.unzip - - val missing = missingAttr ++ aliasedAggregateList.flatten - - if (missing.nonEmpty) { - // Add missing grouping exprs and then project them away after the sort. - Project(a.output, - Sort(withAggsRemoved, global, - Aggregate(grouping, aggs ++ missing, child))) - } else { - s // Nothing we can do here. Return original plan. - } } /** @@ -515,6 +484,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case q: LogicalPlan => q transformExpressions { + case u if !u.childrenResolved => u // Skip until children are resolved. case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { @@ -559,21 +529,79 @@ class Analyzer( } /** - * This rule finds expressions in HAVING clause filters that depend on - * unresolved attributes. It pushes these expressions down to the underlying - * aggregates and then projects them away above the filter. 
+ * This rule finds aggregate expressions that are not in an aggregate operator. For example, + * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the + * underlying aggregate operator and then projected away after the original operator. */ - object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { + object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if aggregate.resolved && containsAggregate(havingCondition) => - - val evaluatedCondition = Alias(havingCondition, "havingCondition")() - val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + case filter @ Filter(havingCondition, + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved && !filter.resolved => + + // Try resolving the condition of the filter as though it is in the aggregate clause + val aggregatedCondition = + Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } - Project(aggregate.output, - Filter(evaluatedCondition.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + case sort @ Sort(sortOrder, global, + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved && !sort.resolved => + + // Try resolving the ordering as though it is in the aggregate clause. + try { + val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) + val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions + + // Expressions that have an aggregate can be pushed down. + val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) + + // Attribute references, that are missing from the order but are present in the grouping + // expressions can also be pushed down. + val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) + val missingAttributes = requiredAttributes -- aggregate.outputSet + val validPushdownAttributes = + missingAttributes.filter(a => grouping.exists(a.semanticEquals)) + + // If resolution was successful and we see the ordering either has an aggregate in it or + // it is missing something that is projected away by the aggregate, add the ordering + // the original aggregate operator. 
+ if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { + val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { + case (order, evaluated) => order.copy(child = evaluated.toAttribute) + } + val aggExprsWithOrdering: Seq[NamedExpression] = + resolvedAggregateOrdering ++ originalAggExprs + + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) + } else { + sort + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. + case ae: AnalysisException => sort + } } protected def containsAggregate(condition: Expression): Boolean = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 10f2902e5eef0..b03a35132325d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -276,6 +276,11 @@ class HiveUDFSuite extends QueryTest { checkAnswer( sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) + + checkAnswer( + sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"), + Seq(Row(" hello world"), Row(" hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() From 6511bf559b736d8e23ae398901c8d78938e66869 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 24 Aug 2015 18:17:51 -0700 Subject: [PATCH 054/802] [SPARK-10118] [SPARKR] [DOCS] Improve SparkR API docs for 1.5 release cc: shivaram ## Summary - Modify `tdname` of expression functions. i.e. `ascii`: `rdname functions` => `rdname ascii` - Replace the dynamical function definitions to the static ones because of thir documentations. ## Generated PDF File https://drive.google.com/file/d/0B9biIZIU47lLX2t6ZjRoRnBTSEU/view?usp=sharing ## JIRA [[SPARK-10118] Improve SparkR API docs for 1.5 release - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-10118) Author: Yu ISHIKAWA Author: Yuu ISHIKAWA Closes #8386 from yu-iskw/SPARK-10118. --- R/create-docs.sh | 2 +- R/pkg/R/column.R | 5 +- R/pkg/R/functions.R | 1603 +++++++++++++++++++++++++++++++++++++++---- R/pkg/R/generics.R | 214 +++--- 4 files changed, 1596 insertions(+), 228 deletions(-) diff --git a/R/create-docs.sh b/R/create-docs.sh index 6a4687b06ecb9..d2ae160b50021 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -39,7 +39,7 @@ pushd $FWDIR mkdir -p pkg/html pushd pkg/html -Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")' +Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 5a07ebd308296..a1f50c383367c 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -169,8 +169,7 @@ setMethod("between", signature(x = "Column"), #' #' @rdname column #' -#' @examples -#' \dontrun{ +#' @examples \dontrun{ #' cast(df$age, "string") #' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) #' } @@ -192,7 +191,7 @@ setMethod("cast", #' #' @rdname column #' @return a matched values as a result of comparing with given values. 
-#' \dontrun{ +#' @examples \dontrun{ #' filter(df, "age in (10, 30)") #' where(df, df$age %in% c(10, 30)) #' } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index b5879bd9ad553..d848730e70433 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -18,69 +18,1298 @@ #' @include generics.R column.R NULL -#' @title S4 expression functions for DataFrame column(s) -#' @description These are expression functions on DataFrame columns - -functions1 <- c( - "abs", "acos", "approxCountDistinct", "ascii", "asin", "atan", - "avg", "base64", "bin", "bitwiseNOT", "cbrt", "ceil", "cos", "cosh", "count", - "crc32", "dayofmonth", "dayofyear", "exp", "explode", "expm1", "factorial", - "first", "floor", "hex", "hour", "initcap", "isNaN", "last", "last_day", - "length", "log", "log10", "log1p", "log2", "lower", "ltrim", "max", "md5", - "mean", "min", "minute", "month", "negate", "quarter", "reverse", - "rint", "round", "rtrim", "second", "sha1", "signum", "sin", "sinh", "size", - "soundex", "sqrt", "sum", "sumDistinct", "tan", "tanh", "toDegrees", - "toRadians", "to_date", "trim", "unbase64", "unhex", "upper", "weekofyear", - "year") -functions2 <- c( - "atan2", "datediff", "hypot", "levenshtein", "months_between", "nanvl", "pmod") - -createFunction1 <- function(name) { - setMethod(name, - signature(x = "Column"), - function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) - column(jc) - }) -} - -createFunction2 <- function(name) { - setMethod(name, - signature(y = "Column"), - function(y, x) { - if (class(x) == "Column") { - x <- x@jc - } - jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) - column(jc) - }) -} +#' Creates a \code{Column} of literal value. +#' +#' The passed in object is returned directly if it is already a \linkS4class{Column}. +#' If the object is a Scala Symbol, it is converted into a \linkS4class{Column} also. +#' Otherwise, a new \linkS4class{Column} is created to represent the literal value. +#' +#' @family normal_funcs +#' @rdname lit +#' @name lit +#' @export +setMethod("lit", signature("ANY"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lit", + ifelse(class(x) == "Column", x@jc, x)) + column(jc) + }) + +#' abs +#' +#' Computes the absolute value. +#' +#' @rdname abs +#' @name abs +#' @family normal_funcs +#' @export +#' @examples \dontrun{abs(df$c)} +setMethod("abs", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "abs", x@jc) + column(jc) + }) + +#' acos +#' +#' Computes the cosine inverse of the given value; the returned angle is in the range +#' 0.0 through pi. +#' +#' @rdname acos +#' @name acos +#' @family math_funcs +#' @export +#' @examples \dontrun{acos(df$c)} +setMethod("acos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "acos", x@jc) + column(jc) + }) + +#' approxCountDistinct +#' +#' Aggregate function: returns the approximate number of distinct items in a group. +#' +#' @rdname approxCountDistinct +#' @name approxCountDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{approxCountDistinct(df$c)} +setMethod("approxCountDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc) + column(jc) + }) + +#' ascii +#' +#' Computes the numeric value of the first character of the string column, and returns the +#' result as a int column. 
+#' +#' @rdname ascii +#' @name ascii +#' @family string_funcs +#' @export +#' @examples \dontrun{\dontrun{ascii(df$c)}} +setMethod("ascii", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ascii", x@jc) + column(jc) + }) + +#' asin +#' +#' Computes the sine inverse of the given value; the returned angle is in the range +#' -pi/2 through pi/2. +#' +#' @rdname asin +#' @name asin +#' @family math_funcs +#' @export +#' @examples \dontrun{asin(df$c)} +setMethod("asin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "asin", x@jc) + column(jc) + }) + +#' atan +#' +#' Computes the tangent inverse of the given value. +#' +#' @rdname atan +#' @name atan +#' @family math_funcs +#' @export +#' @examples \dontrun{atan(df$c)} +setMethod("atan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "atan", x@jc) + column(jc) + }) + +#' avg +#' +#' Aggregate function: returns the average of the values in a group. +#' +#' @rdname avg +#' @name avg +#' @family agg_funcs +#' @export +#' @examples \dontrun{avg(df$c)} +setMethod("avg", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "avg", x@jc) + column(jc) + }) + +#' base64 +#' +#' Computes the BASE64 encoding of a binary column and returns it as a string column. +#' This is the reverse of unbase64. +#' +#' @rdname base64 +#' @name base64 +#' @family string_funcs +#' @export +#' @examples \dontrun{base64(df$c)} +setMethod("base64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "base64", x@jc) + column(jc) + }) + +#' bin +#' +#' An expression that returns the string representation of the binary value of the given long +#' column. For example, bin("12") returns "1100". +#' +#' @rdname bin +#' @name bin +#' @family math_funcs +#' @export +#' @examples \dontrun{bin(df$c)} +setMethod("bin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bin", x@jc) + column(jc) + }) + +#' bitwiseNOT +#' +#' Computes bitwise NOT. +#' +#' @rdname bitwiseNOT +#' @name bitwiseNOT +#' @family normal_funcs +#' @export +#' @examples \dontrun{bitwiseNOT(df$c)} +setMethod("bitwiseNOT", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "bitwiseNOT", x@jc) + column(jc) + }) + +#' cbrt +#' +#' Computes the cube-root of the given value. +#' +#' @rdname cbrt +#' @name cbrt +#' @family math_funcs +#' @export +#' @examples \dontrun{cbrt(df$c)} +setMethod("cbrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cbrt", x@jc) + column(jc) + }) + +#' ceil +#' +#' Computes the ceiling of the given value. +#' +#' @rdname ceil +#' @name ceil +#' @family math_funcs +#' @export +#' @examples \dontrun{ceil(df$c)} +setMethod("ceil", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ceil", x@jc) + column(jc) + }) + +#' cos +#' +#' Computes the cosine of the given value. +#' +#' @rdname cos +#' @name cos +#' @family math_funcs +#' @export +#' @examples \dontrun{cos(df$c)} +setMethod("cos", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cos", x@jc) + column(jc) + }) + +#' cosh +#' +#' Computes the hyperbolic cosine of the given value. 
+#' +#' @rdname cosh +#' @name cosh +#' @family math_funcs +#' @export +#' @examples \dontrun{cosh(df$c)} +setMethod("cosh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "cosh", x@jc) + column(jc) + }) + +#' count +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @rdname count +#' @name count +#' @family agg_funcs +#' @export +#' @examples \dontrun{count(df$c)} +setMethod("count", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "count", x@jc) + column(jc) + }) + +#' crc32 +#' +#' Calculates the cyclic redundancy check value (CRC32) of a binary column and +#' returns the value as a bigint. +#' +#' @rdname crc32 +#' @name crc32 +#' @family misc_funcs +#' @export +#' @examples \dontrun{crc32(df$c)} +setMethod("crc32", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "crc32", x@jc) + column(jc) + }) + +#' dayofmonth +#' +#' Extracts the day of the month as an integer from a given date/timestamp/string. +#' +#' @rdname dayofmonth +#' @name dayofmonth +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofmonth(df$c)} +setMethod("dayofmonth", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofmonth", x@jc) + column(jc) + }) + +#' dayofyear +#' +#' Extracts the day of the year as an integer from a given date/timestamp/string. +#' +#' @rdname dayofyear +#' @name dayofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{dayofyear(df$c)} +setMethod("dayofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "dayofyear", x@jc) + column(jc) + }) + +#' exp +#' +#' Computes the exponential of the given value. +#' +#' @rdname exp +#' @name exp +#' @family math_funcs +#' @export +#' @examples \dontrun{exp(df$c)} +setMethod("exp", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "exp", x@jc) + column(jc) + }) + +#' explode +#' +#' Creates a new row for each element in the given array or map column. +#' +#' @rdname explode +#' @name explode +#' @family collection_funcs +#' @export +#' @examples \dontrun{explode(df$c)} +setMethod("explode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode", x@jc) + column(jc) + }) + +#' expm1 +#' +#' Computes the exponential of the given value minus one. +#' +#' @rdname expm1 +#' @name expm1 +#' @family math_funcs +#' @export +#' @examples \dontrun{expm1(df$c)} +setMethod("expm1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expm1", x@jc) + column(jc) + }) + +#' factorial +#' +#' Computes the factorial of the given value. +#' +#' @rdname factorial +#' @name factorial +#' @family math_funcs +#' @export +#' @examples \dontrun{factorial(df$c)} +setMethod("factorial", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "factorial", x@jc) + column(jc) + }) + +#' first +#' +#' Aggregate function: returns the first value in a group. +#' +#' @rdname first +#' @name first +#' @family agg_funcs +#' @export +#' @examples \dontrun{first(df$c)} +setMethod("first", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "first", x@jc) + column(jc) + }) + +#' floor +#' +#' Computes the floor of the given value. 
+#' +#' @rdname floor +#' @name floor +#' @family math_funcs +#' @export +#' @examples \dontrun{floor(df$c)} +setMethod("floor", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "floor", x@jc) + column(jc) + }) + +#' hex +#' +#' Computes hex value of the given column. +#' +#' @rdname hex +#' @name hex +#' @family math_funcs +#' @export +#' @examples \dontrun{hex(df$c)} +setMethod("hex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hex", x@jc) + column(jc) + }) + +#' hour +#' +#' Extracts the hours as an integer from a given date/timestamp/string. +#' +#' @rdname hour +#' @name hour +#' @family datetime_funcs +#' @export +#' @examples \dontrun{hour(df$c)} +setMethod("hour", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "hour", x@jc) + column(jc) + }) + +#' initcap +#' +#' Returns a new string column by converting the first letter of each word to uppercase. +#' Words are delimited by whitespace. +#' +#' For example, "hello world" will become "Hello World". +#' +#' @rdname initcap +#' @name initcap +#' @family string_funcs +#' @export +#' @examples \dontrun{initcap(df$c)} +setMethod("initcap", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "initcap", x@jc) + column(jc) + }) + +#' isNaN +#' +#' Return true iff the column is NaN. +#' +#' @rdname isNaN +#' @name isNaN +#' @family normal_funcs +#' @export +#' @examples \dontrun{isNaN(df$c)} +setMethod("isNaN", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "isNaN", x@jc) + column(jc) + }) + +#' last +#' +#' Aggregate function: returns the last value in a group. +#' +#' @rdname last +#' @name last +#' @family agg_funcs +#' @export +#' @examples \dontrun{last(df$c)} +setMethod("last", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last", x@jc) + column(jc) + }) + +#' last_day +#' +#' Given a date column, returns the last day of the month which the given date belongs to. +#' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the +#' month in July 2015. +#' +#' @rdname last_day +#' @name last_day +#' @family datetime_funcs +#' @export +#' @examples \dontrun{last_day(df$c)} +setMethod("last_day", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "last_day", x@jc) + column(jc) + }) + +#' length +#' +#' Computes the length of a given string or binary column. +#' +#' @rdname length +#' @name length +#' @family string_funcs +#' @export +#' @examples \dontrun{length(df$c)} +setMethod("length", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "length", x@jc) + column(jc) + }) + +#' log +#' +#' Computes the natural logarithm of the given value. +#' +#' @rdname log +#' @name log +#' @family math_funcs +#' @export +#' @examples \dontrun{log(df$c)} +setMethod("log", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log", x@jc) + column(jc) + }) + +#' log10 +#' +#' Computes the logarithm of the given value in base 10. 
+#' +#' @rdname log10 +#' @name log10 +#' @family math_funcs +#' @export +#' @examples \dontrun{log10(df$c)} +setMethod("log10", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log10", x@jc) + column(jc) + }) + +#' log1p +#' +#' Computes the natural logarithm of the given value plus one. +#' +#' @rdname log1p +#' @name log1p +#' @family math_funcs +#' @export +#' @examples \dontrun{log1p(df$c)} +setMethod("log1p", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log1p", x@jc) + column(jc) + }) + +#' log2 +#' +#' Computes the logarithm of the given column in base 2. +#' +#' @rdname log2 +#' @name log2 +#' @family math_funcs +#' @export +#' @examples \dontrun{log2(df$c)} +setMethod("log2", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "log2", x@jc) + column(jc) + }) + +#' lower +#' +#' Converts a string column to lower case. +#' +#' @rdname lower +#' @name lower +#' @family string_funcs +#' @export +#' @examples \dontrun{lower(df$c)} +setMethod("lower", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "lower", x@jc) + column(jc) + }) + +#' ltrim +#' +#' Trim the spaces from left end for the specified string value. +#' +#' @rdname ltrim +#' @name ltrim +#' @family string_funcs +#' @export +#' @examples \dontrun{ltrim(df$c)} +setMethod("ltrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "ltrim", x@jc) + column(jc) + }) + +#' max +#' +#' Aggregate function: returns the maximum value of the expression in a group. +#' +#' @rdname max +#' @name max +#' @family agg_funcs +#' @export +#' @examples \dontrun{max(df$c)} +setMethod("max", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "max", x@jc) + column(jc) + }) + +#' md5 +#' +#' Calculates the MD5 digest of a binary column and returns the value +#' as a 32 character hex string. +#' +#' @rdname md5 +#' @name md5 +#' @family misc_funcs +#' @export +#' @examples \dontrun{md5(df$c)} +setMethod("md5", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "md5", x@jc) + column(jc) + }) + +#' mean +#' +#' Aggregate function: returns the average of the values in a group. +#' Alias for avg. +#' +#' @rdname mean +#' @name mean +#' @family agg_funcs +#' @export +#' @examples \dontrun{mean(df$c)} +setMethod("mean", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "mean", x@jc) + column(jc) + }) + +#' min +#' +#' Aggregate function: returns the minimum value of the expression in a group. +#' +#' @rdname min +#' @name min +#' @family agg_funcs +#' @export +#' @examples \dontrun{min(df$c)} +setMethod("min", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "min", x@jc) + column(jc) + }) + +#' minute +#' +#' Extracts the minutes as an integer from a given date/timestamp/string. +#' +#' @rdname minute +#' @name minute +#' @family datetime_funcs +#' @export +#' @examples \dontrun{minute(df$c)} +setMethod("minute", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "minute", x@jc) + column(jc) + }) + +#' month +#' +#' Extracts the month as an integer from a given date/timestamp/string. 
+#' +#' @rdname month +#' @name month +#' @family datetime_funcs +#' @export +#' @examples \dontrun{month(df$c)} +setMethod("month", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "month", x@jc) + column(jc) + }) + +#' negate +#' +#' Unary minus, i.e. negate the expression. +#' +#' @rdname negate +#' @name negate +#' @family normal_funcs +#' @export +#' @examples \dontrun{negate(df$c)} +setMethod("negate", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "negate", x@jc) + column(jc) + }) -createFunctions <- function() { - for (name in functions1) { - createFunction1(name) - } - for (name in functions2) { - createFunction2(name) - } -} +#' quarter +#' +#' Extracts the quarter as an integer from a given date/timestamp/string. +#' +#' @rdname quarter +#' @name quarter +#' @family datetime_funcs +#' @export +#' @examples \dontrun{quarter(df$c)} +setMethod("quarter", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "quarter", x@jc) + column(jc) + }) -createFunctions() +#' reverse +#' +#' Reverses the string column and returns it as a new string column. +#' +#' @rdname reverse +#' @name reverse +#' @family string_funcs +#' @export +#' @examples \dontrun{reverse(df$c)} +setMethod("reverse", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "reverse", x@jc) + column(jc) + }) -#' @rdname functions -#' @return Creates a Column class of literal value. -setMethod("lit", signature("ANY"), +#' rint +#' +#' Returns the double value that is closest in value to the argument and +#' is equal to a mathematical integer. +#' +#' @rdname rint +#' @name rint +#' @family math_funcs +#' @export +#' @examples \dontrun{rint(df$c)} +setMethod("rint", + signature(x = "Column"), function(x) { - jc <- callJStatic("org.apache.spark.sql.functions", - "lit", - ifelse(class(x) == "Column", x@jc, x)) + jc <- callJStatic("org.apache.spark.sql.functions", "rint", x@jc) + column(jc) + }) + +#' round +#' +#' Returns the value of the column `e` rounded to 0 decimal places. +#' +#' @rdname round +#' @name round +#' @family math_funcs +#' @export +#' @examples \dontrun{round(df$c)} +setMethod("round", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "round", x@jc) + column(jc) + }) + +#' rtrim +#' +#' Trim the spaces from right end for the specified string value. +#' +#' @rdname rtrim +#' @name rtrim +#' @family string_funcs +#' @export +#' @examples \dontrun{rtrim(df$c)} +setMethod("rtrim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "rtrim", x@jc) + column(jc) + }) + +#' second +#' +#' Extracts the seconds as an integer from a given date/timestamp/string. +#' +#' @rdname second +#' @name second +#' @family datetime_funcs +#' @export +#' @examples \dontrun{second(df$c)} +setMethod("second", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "second", x@jc) + column(jc) + }) + +#' sha1 +#' +#' Calculates the SHA-1 digest of a binary column and returns the value +#' as a 40 character hex string. 
+#' +#' @rdname sha1 +#' @name sha1 +#' @family misc_funcs +#' @export +#' @examples \dontrun{sha1(df$c)} +setMethod("sha1", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha1", x@jc) + column(jc) + }) + +#' signum +#' +#' Computes the signum of the given value. +#' +#' @rdname signum +#' @name signum +#' @family math_funcs +#' @export +#' @examples \dontrun{signum(df$c)} +setMethod("signum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "signum", x@jc) + column(jc) + }) + +#' sin +#' +#' Computes the sine of the given value. +#' +#' @rdname sin +#' @name sin +#' @family math_funcs +#' @export +#' @examples \dontrun{sin(df$c)} +setMethod("sin", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sin", x@jc) + column(jc) + }) + +#' sinh +#' +#' Computes the hyperbolic sine of the given value. +#' +#' @rdname sinh +#' @name sinh +#' @family math_funcs +#' @export +#' @examples \dontrun{sinh(df$c)} +setMethod("sinh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sinh", x@jc) + column(jc) + }) + +#' size +#' +#' Returns length of array or map. +#' +#' @rdname size +#' @name size +#' @family collection_funcs +#' @export +#' @examples \dontrun{size(df$c)} +setMethod("size", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "size", x@jc) + column(jc) + }) + +#' soundex +#' +#' Return the soundex code for the specified expression. +#' +#' @rdname soundex +#' @name soundex +#' @family string_funcs +#' @export +#' @examples \dontrun{soundex(df$c)} +setMethod("soundex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "soundex", x@jc) + column(jc) + }) + +#' sqrt +#' +#' Computes the square root of the specified float value. +#' +#' @rdname sqrt +#' @name sqrt +#' @family math_funcs +#' @export +#' @examples \dontrun{sqrt(df$c)} +setMethod("sqrt", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sqrt", x@jc) + column(jc) + }) + +#' sum +#' +#' Aggregate function: returns the sum of all values in the expression. +#' +#' @rdname sum +#' @name sum +#' @family agg_funcs +#' @export +#' @examples \dontrun{sum(df$c)} +setMethod("sum", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sum", x@jc) + column(jc) + }) + +#' sumDistinct +#' +#' Aggregate function: returns the sum of distinct values in the expression. +#' +#' @rdname sumDistinct +#' @name sumDistinct +#' @family agg_funcs +#' @export +#' @examples \dontrun{sumDistinct(df$c)} +setMethod("sumDistinct", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sumDistinct", x@jc) + column(jc) + }) + +#' tan +#' +#' Computes the tangent of the given value. +#' +#' @rdname tan +#' @name tan +#' @family math_funcs +#' @export +#' @examples \dontrun{tan(df$c)} +setMethod("tan", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tan", x@jc) + column(jc) + }) + +#' tanh +#' +#' Computes the hyperbolic tangent of the given value. 
+#' +#' @rdname tanh +#' @name tanh +#' @family math_funcs +#' @export +#' @examples \dontrun{tanh(df$c)} +setMethod("tanh", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "tanh", x@jc) + column(jc) + }) + +#' toDegrees +#' +#' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. +#' +#' @rdname toDegrees +#' @name toDegrees +#' @family math_funcs +#' @export +#' @examples \dontrun{toDegrees(df$c)} +setMethod("toDegrees", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toDegrees", x@jc) + column(jc) + }) + +#' toRadians +#' +#' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. +#' +#' @rdname toRadians +#' @name toRadians +#' @family math_funcs +#' @export +#' @examples \dontrun{toRadians(df$c)} +setMethod("toRadians", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "toRadians", x@jc) + column(jc) + }) + +#' to_date +#' +#' Converts the column into DateType. +#' +#' @rdname to_date +#' @name to_date +#' @family datetime_funcs +#' @export +#' @examples \dontrun{to_date(df$c)} +setMethod("to_date", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc) + column(jc) + }) + +#' trim +#' +#' Trim the spaces from both ends for the specified string column. +#' +#' @rdname trim +#' @name trim +#' @family string_funcs +#' @export +#' @examples \dontrun{trim(df$c)} +setMethod("trim", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "trim", x@jc) + column(jc) + }) + +#' unbase64 +#' +#' Decodes a BASE64 encoded string column and returns it as a binary column. +#' This is the reverse of base64. +#' +#' @rdname unbase64 +#' @name unbase64 +#' @family string_funcs +#' @export +#' @examples \dontrun{unbase64(df$c)} +setMethod("unbase64", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unbase64", x@jc) + column(jc) + }) + +#' unhex +#' +#' Inverse of hex. Interprets each pair of characters as a hexadecimal number +#' and converts to the byte representation of number. +#' +#' @rdname unhex +#' @name unhex +#' @family math_funcs +#' @export +#' @examples \dontrun{unhex(df$c)} +setMethod("unhex", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "unhex", x@jc) + column(jc) + }) + +#' upper +#' +#' Converts a string column to upper case. +#' +#' @rdname upper +#' @name upper +#' @family string_funcs +#' @export +#' @examples \dontrun{upper(df$c)} +setMethod("upper", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "upper", x@jc) + column(jc) + }) + +#' weekofyear +#' +#' Extracts the week number as an integer from a given date/timestamp/string. +#' +#' @rdname weekofyear +#' @name weekofyear +#' @family datetime_funcs +#' @export +#' @examples \dontrun{weekofyear(df$c)} +setMethod("weekofyear", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "weekofyear", x@jc) + column(jc) + }) + +#' year +#' +#' Extracts the year as an integer from a given date/timestamp/string. 
+#' +#' @rdname year +#' @name year +#' @family datetime_funcs +#' @export +#' @examples \dontrun{year(df$c)} +setMethod("year", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "year", x@jc) column(jc) }) +#' atan2 +#' +#' Returns the angle theta from the conversion of rectangular coordinates (x, y) to +#' polar coordinates (r, theta). +#' +#' @rdname atan2 +#' @name atan2 +#' @family math_funcs +#' @export +#' @examples \dontrun{atan2(df$c, x)} +setMethod("atan2", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "atan2", y@jc, x) + column(jc) + }) + +#' datediff +#' +#' Returns the number of days from `start` to `end`. +#' +#' @rdname datediff +#' @name datediff +#' @family datetime_funcs +#' @export +#' @examples \dontrun{datediff(df$c, x)} +setMethod("datediff", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "datediff", y@jc, x) + column(jc) + }) + +#' hypot +#' +#' Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. +#' +#' @rdname hypot +#' @name hypot +#' @family math_funcs +#' @export +#' @examples \dontrun{hypot(df$c, x)} +setMethod("hypot", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "hypot", y@jc, x) + column(jc) + }) + +#' levenshtein +#' +#' Computes the Levenshtein distance of the two given string columns. +#' +#' @rdname levenshtein +#' @name levenshtein +#' @family string_funcs +#' @export +#' @examples \dontrun{levenshtein(df$c, x)} +setMethod("levenshtein", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "levenshtein", y@jc, x) + column(jc) + }) + +#' months_between +#' +#' Returns number of months between dates `date1` and `date2`. +#' +#' @rdname months_between +#' @name months_between +#' @family datetime_funcs +#' @export +#' @examples \dontrun{months_between(df$c, x)} +setMethod("months_between", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "months_between", y@jc, x) + column(jc) + }) + +#' nanvl +#' +#' Returns col1 if it is not NaN, or col2 if col1 is NaN. +#' hhBoth inputs should be floating point columns (DoubleType or FloatType). +#' +#' @rdname nanvl +#' @name nanvl +#' @family normal_funcs +#' @export +#' @examples \dontrun{nanvl(df$c, x)} +setMethod("nanvl", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "nanvl", y@jc, x) + column(jc) + }) + +#' pmod +#' +#' Returns the positive value of dividend mod divisor. +#' +#' @rdname pmod +#' @name pmod +#' @docType methods +#' @family math_funcs +#' @export +#' @examples \dontrun{pmod(df$c, x)} +setMethod("pmod", signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", "pmod", y@jc, x) + column(jc) + }) + + #' Approx Count Distinct #' -#' @rdname functions +#' @family agg_funcs +#' @rdname approxCountDistinct +#' @name approxCountDistinct #' @return the approximate number of distinct items in a group. 
+#' @export setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.95) { @@ -90,8 +1319,11 @@ setMethod("approxCountDistinct", #' Count Distinct #' -#' @rdname functions +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct #' @return the number of distinct items in a group. +#' @export setMethod("countDistinct", signature(x = "Column"), function(x, ...) { @@ -103,8 +1335,15 @@ setMethod("countDistinct", column(jc) }) -#' @rdname functions -#' @return Concatenates multiple input string columns together into a single string column. + +#' concat +#' +#' Concatenates multiple input string columns together into a single string column. +#' +#' @family string_funcs +#' @rdname concat +#' @name concat +#' @export setMethod("concat", signature(x = "Column"), function(x, ...) { @@ -113,9 +1352,15 @@ setMethod("concat", column(jc) }) -#' @rdname functions -#' @return Returns the greatest value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' greatest +#' +#' Returns the greatest value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null if all parameters are null. +#' +#' @family normal_funcs +#' @rdname greatest +#' @name greatest +#' @export setMethod("greatest", signature(x = "Column"), function(x, ...) { @@ -125,9 +1370,15 @@ setMethod("greatest", column(jc) }) -#' @rdname functions -#' @return Returns the least value of the list of column names, skipping null values. -#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' least +#' +#' Returns the least value of the list of column names, skipping null values. +#' This function takes at least 2 parameters. It will return null iff all parameters are null. +#' +#' @family normal_funcs +#' @rdname least +#' @name least +#' @export setMethod("least", signature(x = "Column"), function(x, ...) { @@ -137,30 +1388,58 @@ setMethod("least", column(jc) }) -#' @rdname functions +#' ceiling +#' +#' Computes the ceiling of the given value. +#' +#' @family math_funcs +#' @rdname ceil +#' @name ceil #' @aliases ceil +#' @export setMethod("ceiling", signature(x = "Column"), function(x) { ceil(x) }) -#' @rdname functions +#' sign +#' +#' Computes the signum of the given value. +#' +#' @family math_funcs +#' @rdname signum +#' @name signum #' @aliases signum +#' @export setMethod("sign", signature(x = "Column"), function(x) { signum(x) }) -#' @rdname functions +#' n_distinct +#' +#' Aggregate function: returns the number of distinct items in a group. +#' +#' @family agg_funcs +#' @rdname countDistinct +#' @name countDistinct #' @aliases countDistinct +#' @export setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) }) -#' @rdname functions +#' n +#' +#' Aggregate function: returns the number of items in a group. +#' +#' @family agg_funcs +#' @rdname count +#' @name count #' @aliases count +#' @export setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -171,13 +1450,16 @@ setMethod("n", signature(x = "Column"), #' Converts a date/timestamp/string to a value of string in the format specified by the date #' format given by the second argument. #' -#' A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All -#' pattern letters of `java.text.SimpleDateFormat` can be used. 
+#' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All +#' pattern letters of \code{java.text.SimpleDateFormat} can be used. #' -#' NOTE: Use when ever possible specialized functions like `year`. These benefit from a +#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname date_format +#' @name date_format +#' @export setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -188,7 +1470,10 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' Assumes given timestamp is UTC and converts to given timezone. #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname from_utc_timestamp +#' @name from_utc_timestamp +#' @export setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -203,7 +1488,10 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' NOTE: The position is not zero based, but 1 based index, returns 0 if substr #' could not be found in str. #' -#' @rdname functions +#' @family string_funcs +#' @rdname instr +#' @name instr +#' @export setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -215,13 +1503,16 @@ setMethod("instr", signature(y = "Column", x = "character"), #' Given a date column, returns the first date which is later than the value of the date column #' that is on the specified day of the week. #' -#' For example, `next <- day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first +#' For example, \code{next_day('2015-07-27', "Sunday")} returns 2015-08-02 because that is the first #' Sunday after 2015-07-27. #' #' Day of the week parameter is case insensitive, and accepts: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname next_day +#' @name next_day +#' @export setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -232,7 +1523,10 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' Assumes given timestamp is in given timezone and converts to UTC. #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname to_utc_timestamp +#' @name to_utc_timestamp +#' @export setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -243,7 +1537,11 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' Returns the date that is numMonths after startDate. 
#' -#' @rdname functions +#' @name add_months +#' @family datetime_funcs +#' @rdname add_months +#' @name add_months +#' @export setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -254,7 +1552,10 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' #' Returns the date that is `days` days after `start` #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname date_add +#' @name date_add +#' @export setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -265,7 +1566,10 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' #' Returns the date that is `days` days before `start` #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname date_sub +#' @name date_sub +#' @export setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -280,7 +1584,10 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' If d is 0, the result has no decimal point or fractional part. #' If d < 0, the result will be null.' #' -#' @rdname functions +#' @family string_funcs +#' @rdname format_number +#' @name format_number +#' @export setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -294,9 +1601,12 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' Calculates the SHA-2 family of hash functions of a binary column and #' returns the value as a hex string. #' -#' @rdname functions #' @param y column to compute SHA-2 on. #' @param x one of 224, 256, 384, or 512. +#' @family misc_funcs +#' @rdname sha2 +#' @name sha2 +#' @export setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -308,7 +1618,10 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' Shift the the given value numBits left. If the given value is a long value, this function #' will return a long value else it will return an integer value. #' -#' @rdname functions +#' @family math_funcs +#' @rdname shiftLeft +#' @name shiftLeft +#' @export setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -322,7 +1635,10 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' Shift the the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' -#' @rdname functions +#' @family math_funcs +#' @rdname shiftRight +#' @name shiftRight +#' @export setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -336,7 +1652,10 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' Unsigned shift the the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. 
#' -#' @rdname functions +#' @family math_funcs +#' @rdname shiftRightUnsigned +#' @name shiftRightUnsigned +#' @export setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -350,7 +1669,10 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' Concatenates multiple input string columns together into a single string column, #' using the given separator. #' -#' @rdname functions +#' @family string_funcs +#' @rdname concat_ws +#' @name concat_ws +#' @export setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) @@ -362,7 +1684,10 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' #' Convert a number in a string column from one base to another. #' -#' @rdname functions +#' @family math_funcs +#' @rdname conv +#' @name conv +#' @export setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -378,7 +1703,10 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' Parses the expression string into the column that it represents, similar to #' DataFrame.selectExpr #' -#' @rdname functions +#' @family normal_funcs +#' @rdname expr +#' @name expr +#' @export setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -389,7 +1717,10 @@ setMethod("expr", signature(x = "character"), #' #' Formats the arguments in printf-style and returns the result as a string column. #' -#' @rdname functions +#' @family string_funcs +#' @rdname format_string +#' @name format_string +#' @export setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) { jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) @@ -405,7 +1736,10 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' representing the timestamp of that moment in the current system time zone in the given #' format. #' -#' @rdname functions +#' @family datetime_funcs +#' @rdname from_unixtime +#' @name from_unixtime +#' @export setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -420,7 +1754,10 @@ setMethod("from_unixtime", signature(x = "Column"), #' NOTE: The position is not zero based, but 1 based index, returns 0 if substr #' could not be found in str. #' -#' @rdname functions +#' @family string_funcs +#' @rdname locate +#' @name locate +#' @export setMethod("locate", signature(substr = "character", str = "Column"), function(substr, str, pos = 0) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -433,7 +1770,10 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' #' Left-pad the string column with #' -#' @rdname functions +#' @family string_funcs +#' @rdname lpad +#' @name lpad +#' @export setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -446,12 +1786,19 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' #' Generate a random column with i.i.d. samples from U[0.0, 1.0]. 
#' -#' @rdname functions +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") column(jc) }) +#' @family normal_funcs +#' @rdname rand +#' @name rand +#' @export setMethod("rand", signature(seed = "numeric"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) @@ -462,12 +1809,19 @@ setMethod("rand", signature(seed = "numeric"), #' #' Generate a column with i.i.d. samples from the standard normal distribution. #' -#' @rdname functions +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") column(jc) }) +#' @family normal_funcs +#' @rdname randn +#' @name randn +#' @export setMethod("randn", signature(seed = "numeric"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) @@ -478,7 +1832,10 @@ setMethod("randn", signature(seed = "numeric"), #' #' Extract a specific(idx) group identified by a java regex, from the specified string column. #' -#' @rdname functions +#' @family string_funcs +#' @rdname regexp_extract +#' @name regexp_extract +#' @export setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -492,7 +1849,10 @@ setMethod("regexp_extract", #' #' Replace all substrings of the specified string value that match regexp with rep. #' -#' @rdname functions +#' @family string_funcs +#' @rdname regexp_replace +#' @name regexp_replace +#' @export setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -506,7 +1866,10 @@ setMethod("regexp_replace", #' #' Right-padded with pad to a length of len. #' -#' @rdname functions +#' @family string_funcs +#' @rdname rpad +#' @name rpad +#' @export setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -522,7 +1885,10 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' returned. If count is negative, every to the right of the final delimiter (counting from the #' right) is returned. substring <- index performs a case-sensitive match when searching for delim. #' -#' @rdname functions +#' @family string_funcs +#' @rdname substring_index +#' @name substring_index +#' @export setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -539,7 +1905,10 @@ setMethod("substring_index", #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' -#' @rdname functions +#' @family string_funcs +#' @rdname translate +#' @name translate +#' @export setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -552,30 +1921,28 @@ setMethod("translate", #' #' Gets current Unix timestamp in seconds. 
#' -#' @rdname functions +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") column(jc) }) -#' unix_timestamp -#' -#' Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), -#' using the default timezone and the default locale, return null if fail. -#' -#' @rdname functions +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) -#' unix_timestamp -#' -#' Convert time string with given pattern -#' (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) -#' to Unix time stamp (in seconds), return null if fail. -#' -#' @rdname functions +#' @family datetime_funcs +#' @rdname unix_timestamp +#' @name unix_timestamp +#' @export setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) @@ -586,7 +1953,10 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' -#' @rdname column +#' @family normal_funcs +#' @rdname when +#' @name when +#' @export setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc @@ -597,10 +1967,13 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' ifelse #' -#' Evaluates a list of conditions and returns `yes` if the conditions are satisfied. -#' Otherwise `no` is returned for unmatched conditions. +#' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. +#' Otherwise \code{no} is returned for unmatched conditions. #' -#' @rdname column +#' @family normal_funcs +#' @rdname ifelse +#' @name ifelse +#' @export setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 84cb8dfdaa2dd..610a8c31223cd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -567,10 +567,6 @@ setGeneric("withColumnRenamed", ###################### Column Methods ########################## -#' @rdname column -#' @export -setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) - #' @rdname column #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) @@ -587,10 +583,6 @@ setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) -#' @rdname column -#' @export -setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") }) - #' @rdname column #' @export setGeneric("desc", function(x) { standardGeneric("desc") }) @@ -607,10 +599,6 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) 
{ standardGeneric("getItem") }) -#' @rdname column -#' @export -setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) - #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -619,30 +607,10 @@ setGeneric("isNull", function(x) { standardGeneric("isNull") }) #' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) -#' @rdname column -#' @export -setGeneric("last", function(x) { standardGeneric("last") }) - #' @rdname column #' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) -#' @rdname column -#' @export -setGeneric("lower", function(x) { standardGeneric("lower") }) - -#' @rdname column -#' @export -setGeneric("n", function(x) { standardGeneric("n") }) - -#' @rdname column -#' @export -setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) - -#' @rdname column -#' @export -setGeneric("rint", function(x, ...) { standardGeneric("rint") }) - #' @rdname column #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) @@ -662,312 +630,340 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) ###################### Expression Function Methods ########################## -#' @rdname functions +#' @rdname add_months #' @export setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) -#' @rdname functions +#' @rdname approxCountDistinct +#' @export +setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") }) + +#' @rdname ascii #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) -#' @rdname functions +#' @rdname avg #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) -#' @rdname functions +#' @rdname base64 #' @export setGeneric("base64", function(x) { standardGeneric("base64") }) -#' @rdname functions +#' @rdname bin #' @export setGeneric("bin", function(x) { standardGeneric("bin") }) -#' @rdname functions +#' @rdname bitwiseNOT #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) -#' @rdname functions +#' @rdname cbrt #' @export setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) -#' @rdname functions +#' @rdname ceil #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) -#' @rdname functions +#' @rdname concat #' @export setGeneric("concat", function(x, ...) { standardGeneric("concat") }) -#' @rdname functions +#' @rdname concat_ws #' @export setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) -#' @rdname functions +#' @rdname conv #' @export setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) -#' @rdname functions +#' @rdname countDistinct +#' @export +setGeneric("countDistinct", function(x, ...) 
{ standardGeneric("countDistinct") }) + +#' @rdname crc32 #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) -#' @rdname functions +#' @rdname datediff #' @export setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) -#' @rdname functions +#' @rdname date_add #' @export setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) -#' @rdname functions +#' @rdname date_format #' @export setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) -#' @rdname functions +#' @rdname date_sub #' @export setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) -#' @rdname functions +#' @rdname dayofmonth #' @export setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) -#' @rdname functions +#' @rdname dayofyear #' @export setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) -#' @rdname functions +#' @rdname explode #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) -#' @rdname functions +#' @rdname expr #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) -#' @rdname functions +#' @rdname from_utc_timestamp #' @export setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) -#' @rdname functions +#' @rdname format_number #' @export setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) -#' @rdname functions +#' @rdname format_string #' @export setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) -#' @rdname functions +#' @rdname from_unixtime #' @export setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) -#' @rdname functions +#' @rdname greatest #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) -#' @rdname functions +#' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) -#' @rdname functions +#' @rdname hour #' @export setGeneric("hour", function(x) { standardGeneric("hour") }) -#' @rdname functions +#' @rdname hypot +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + +#' @rdname initcap #' @export setGeneric("initcap", function(x) { standardGeneric("initcap") }) -#' @rdname functions +#' @rdname instr #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) -#' @rdname functions +#' @rdname isNaN #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) -#' @rdname functions +#' @rdname last +#' @export +setGeneric("last", function(x) { standardGeneric("last") }) + +#' @rdname last_day #' @export setGeneric("last_day", function(x) { standardGeneric("last_day") }) -#' @rdname functions +#' @rdname least #' @export setGeneric("least", function(x, ...) { standardGeneric("least") }) -#' @rdname functions +#' @rdname levenshtein #' @export setGeneric("levenshtein", function(y, x) { standardGeneric("levenshtein") }) -#' @rdname functions +#' @rdname lit #' @export setGeneric("lit", function(x) { standardGeneric("lit") }) -#' @rdname functions +#' @rdname locate #' @export setGeneric("locate", function(substr, str, ...) 
{ standardGeneric("locate") }) -#' @rdname functions +#' @rdname lower #' @export setGeneric("lower", function(x) { standardGeneric("lower") }) -#' @rdname functions +#' @rdname lpad #' @export setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) -#' @rdname functions +#' @rdname ltrim #' @export setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) -#' @rdname functions +#' @rdname md5 #' @export setGeneric("md5", function(x) { standardGeneric("md5") }) -#' @rdname functions +#' @rdname minute #' @export setGeneric("minute", function(x) { standardGeneric("minute") }) -#' @rdname functions +#' @rdname month #' @export setGeneric("month", function(x) { standardGeneric("month") }) -#' @rdname functions +#' @rdname months_between #' @export setGeneric("months_between", function(y, x) { standardGeneric("months_between") }) -#' @rdname functions +#' @rdname count +#' @export +setGeneric("n", function(x) { standardGeneric("n") }) + +#' @rdname nanvl #' @export setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) -#' @rdname functions +#' @rdname negate #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) -#' @rdname functions +#' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) -#' @rdname functions +#' @rdname countDistinct +#' @export +setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) + +#' @rdname pmod #' @export setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) -#' @rdname functions +#' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) -#' @rdname functions +#' @rdname rand #' @export setGeneric("rand", function(seed) { standardGeneric("rand") }) -#' @rdname functions +#' @rdname randn #' @export setGeneric("randn", function(seed) { standardGeneric("randn") }) -#' @rdname functions +#' @rdname regexp_extract #' @export setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) -#' @rdname functions +#' @rdname regexp_replace #' @export setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) -#' @rdname functions +#' @rdname reverse #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) -#' @rdname functions +#' @rdname rint +#' @export +setGeneric("rint", function(x, ...) 
{ standardGeneric("rint") }) + +#' @rdname rpad #' @export setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) -#' @rdname functions +#' @rdname rtrim #' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) -#' @rdname functions +#' @rdname second #' @export setGeneric("second", function(x) { standardGeneric("second") }) -#' @rdname functions +#' @rdname sha1 #' @export setGeneric("sha1", function(x) { standardGeneric("sha1") }) -#' @rdname functions +#' @rdname sha2 #' @export setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) -#' @rdname functions +#' @rdname shiftLeft #' @export setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) -#' @rdname functions +#' @rdname shiftRight #' @export setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) -#' @rdname functions +#' @rdname shiftRightUnsigned #' @export setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname functions +#' @rdname signum #' @export setGeneric("signum", function(x) { standardGeneric("signum") }) -#' @rdname functions +#' @rdname size #' @export setGeneric("size", function(x) { standardGeneric("size") }) -#' @rdname functions +#' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) -#' @rdname functions +#' @rdname substring_index #' @export setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) -#' @rdname functions +#' @rdname sumDistinct #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) -#' @rdname functions +#' @rdname toDegrees #' @export setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) -#' @rdname functions +#' @rdname toRadians #' @export setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) -#' @rdname functions +#' @rdname to_date #' @export setGeneric("to_date", function(x) { standardGeneric("to_date") }) -#' @rdname functions +#' @rdname to_utc_timestamp #' @export setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) -#' @rdname functions +#' @rdname translate #' @export setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) -#' @rdname functions +#' @rdname trim #' @export setGeneric("trim", function(x) { standardGeneric("trim") }) -#' @rdname functions +#' @rdname unbase64 #' @export setGeneric("unbase64", function(x) { standardGeneric("unbase64") }) -#' @rdname functions +#' @rdname unhex #' @export setGeneric("unhex", function(x) { standardGeneric("unhex") }) -#' @rdname functions +#' @rdname unix_timestamp #' @export setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) -#' @rdname functions +#' @rdname upper #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) -#' @rdname functions +#' @rdname weekofyear #' @export setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) -#' @rdname functions +#' @rdname year #' @export setGeneric("year", function(x) { standardGeneric("year") }) From 642c43c81c835139e3f35dfd6a215d668a474203 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 24 Aug 2015 19:45:41 -0700 Subject: [PATCH 055/802] [SQL] [MINOR] [DOC] Clarify docs for inferring DataFrame from RDD of Products * Makes `SQLImplicits.rddToDataFrameHolder` scaladoc consistent with `SQLContext.createDataFrame[A <: Product](rdd: RDD[A])` since the former is essentially a wrapper for the 
latter * Clarifies `createDataFrame[A <: Product]` scaladoc to apply for any `RDD[Product]`, not just case classes Author: Feynman Liang Closes #8406 from feynmanliang/sql-doc-fixes. --- sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 126c9c6f839c7..a1eea09e0477b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -350,7 +350,7 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * :: Experimental :: - * Creates a DataFrame from an RDD of case classes. + * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). * * @group dataframes * @since 1.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 47b6f80bed483..bf03c61088426 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -40,7 +40,7 @@ private[sql] abstract class SQLImplicits { implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) /** - * Creates a DataFrame from an RDD of case classes or tuples. + * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples). * @since 1.3.0 */ implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = { From a0c0aae1defe5e1e57704065631d201f8e3f6bac Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 25 Aug 2015 12:49:50 +0800 Subject: [PATCH 056/802] [SPARK-10121] [SQL] Thrift server always use the latest class loader provided by the conf of executionHive's state https://issues.apache.org/jira/browse/SPARK-10121 Looks like the problem is that if we add a jar through another thread, the thread handling the JDBC session will not get the latest classloader. Author: Yin Huai Closes #8368 from yhuai/SPARK-10121. --- .../SparkExecuteStatementOperation.scala | 6 +++ .../HiveThriftServer2Suites.scala | 54 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 833bf62d47d07..02cc7e5efa521 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -159,6 +159,12 @@ private[hive] class SparkExecuteStatementOperation( // User information is part of the metastore client member in Hive hiveContext.setSession(currentSqlSession) + // Always use the latest class loader provided by executionHive's state. 
+ val executionHiveClassLoader = + hiveContext.executionHive.state.getConf.getClassLoader + sessionHive.getConf.setClassLoader(executionHiveClassLoader) + parentSessionState.getConf.setClassLoader(executionHiveClassLoader) + Hive.set(sessionHive) SessionState.setCurrentSessionState(parentSessionState) try { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ded42bca9971e..b72249b3bf8c0 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -377,6 +377,60 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { rs2.close() } } + + test("test add jar") { + withMultipleConnectionJdbcStatement( + { + statement => + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + statement.executeQuery(s"ADD JAR $jarFile") + }, + + { + statement => + val queries = Seq( + "DROP TABLE IF EXISTS smallKV", + "CREATE TABLE smallKV(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE smallKV", + "DROP TABLE IF EXISTS addJar", + """CREATE TABLE addJar(key string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + + queries.foreach(statement.execute) + + statement.executeQuery( + """ + |INSERT INTO TABLE addJar SELECT 'k1' as key FROM smallKV limit 1 + """.stripMargin) + + val actualResult = + statement.executeQuery("SELECT key FROM addJar") + val actualResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (actualResult.next()) { + actualResultBuffer += actualResult.getString(1) + } + actualResult.close() + + val expectedResult = + statement.executeQuery("SELECT 'k1'") + val expectedResultBuffer = new collection.mutable.ArrayBuffer[String]() + while (expectedResult.next()) { + expectedResultBuffer += expectedResult.getString(1) + } + expectedResult.close() + + assert(expectedResultBuffer === actualResultBuffer) + + statement.executeQuery("DROP TABLE IF EXISTS addJar") + statement.executeQuery("DROP TABLE IF EXISTS smallKV") + } + ) + } } class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { From 5175ca0c85b10045d12c3fb57b1e52278a413ecf Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 24 Aug 2015 23:15:27 -0700 Subject: [PATCH 057/802] [SPARK-10178] [SQL] HiveComparisionTest should print out dependent tables In `HiveComparisionTest`s it is possible to fail a query of the form `SELECT * FROM dest1`, where `dest1` is the query that is actually computing the incorrect results. To aid debugging this patch improves the harness to also print these query plans and their results. Author: Michael Armbrust Closes #8388 from marmbrus/generatedTables. 
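As a purely illustrative aside before the diff itself: the debugging idea in this commit message can be sketched with toy classes, namely walking a plan-like tree of the failing query, collecting the tables it reads, and dumping those alongside the mismatch. The trait and class names below are invented for the example and are not Spark's planner types.
```
sealed trait Plan { def children: Seq[Plan] }
case class TableScan(tableName: String) extends Plan { def children: Seq[Plan] = Nil }
case class Op(name: String, children: Plan*) extends Plan

object DependentTables {
  // Collect every table name referenced anywhere in the (toy) plan tree.
  def tablesRead(plan: Plan): Set[String] = plan match {
    case TableScan(name) => Set(name)
    case other           => other.children.flatMap(tablesRead).toSet
  }

  def main(args: Array[String]): Unit = {
    val failingQuery = Op("Project", Op("Filter", TableScan("dest1")))
    // Any earlier statement in the test case that writes one of these tables
    // is worth printing next to the mismatch, which is what the patch automates.
    println(tablesRead(failingQuery)) // Set(dest1)
  }
}
```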
--- .../hive/execution/HiveComparisonTest.scala | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 2bdb0e11878e5..4d45249d9c6b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution import java.io._ +import scala.util.control.NonFatal + import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} import org.apache.spark.{Logging, SparkFunSuite} @@ -386,11 +388,45 @@ abstract class HiveComparisonTest hiveCacheFiles.foreach(_.delete()) } + // If this query is reading other tables that were created during this test run + // also print out the query plans and results for those. + val computedTablesMessages: String = try { + val tablesRead = new TestHive.QueryExecution(query).executedPlan.collect { + case ts: HiveTableScan => ts.relation.tableName + }.toSet + + TestHive.reset() + val executions = queryList.map(new TestHive.QueryExecution(_)) + executions.foreach(_.toRdd) + val tablesGenerated = queryList.zip(executions).flatMap { + case (q, e) => e.executedPlan.collect { + case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => + (q, e, i) + } + } + + tablesGenerated.map { case (hiveql, execution, insert) => + s""" + |=== Generated Table === + |$hiveql + |$execution + |== Results == + |${insert.child.execute().collect().mkString("\n")} + """.stripMargin + }.mkString("\n") + + } catch { + case NonFatal(e) => + logError("Failed to compute generated tables", e) + s"Couldn't compute dependent tables: $e" + } + val errorMessage = s""" |Results do not match for $testCaseName: |$hiveQuery\n${hiveQuery.analyzed.output.map(_.name).mkString("\t")} |$resultComparison + |$computedTablesMessages """.stripMargin stringToFile(new File(wrongDirectory, testCaseName), errorMessage + consoleTestCase) From d9c25dec87e6da7d66a47ff94e7eefa008081b9d Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Mon, 24 Aug 2015 23:26:14 -0700 Subject: [PATCH 058/802] =?UTF-8?q?[SPARK-9786]=20[STREAMING]=20[KAFKA]=20?= =?UTF-8?q?fix=20backpressure=20so=20it=20works=20with=20defa=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ult maxRatePerPartition setting of 0 Author: cody koeninger Closes #8413 from koeninger/backpressure-testing-master. 
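For readers skimming the change below: the core of the fix is how the per-partition rate is derived when the static cap is left at its default of 0. Here is a self-contained sketch of that capping rule; the object, method, and argument names are made up for illustration and are not the actual DirectKafkaInputDStream members.
```
// Standalone sketch of the per-partition rate cap described above.
// A cap of 0 (the default) means "no static limit", so only the
// estimator's limit, when present and positive, should apply.
object RateCapSketch {
  def effectiveRatePerPartition(
      estimatedRate: Option[Long],   // from the backpressure estimator
      maxRatePerPartition: Long,     // 0 = unlimited (the default)
      numPartitions: Int): Long = {
    estimatedRate
      .filter(_ > 0)
      .map { limit =>
        if (maxRatePerPartition > 0) math.min(maxRatePerPartition, limit / numPartitions)
        else limit / numPartitions
      }
      .getOrElse(maxRatePerPartition)
  }

  def main(args: Array[String]): Unit = {
    // With the default cap of 0, the estimator's limit is still honored.
    println(effectiveRatePerPartition(Some(1000L), 0L, 4))    // 250
    // An explicit cap wins when it is smaller than the estimate.
    println(effectiveRatePerPartition(Some(1000L), 100L, 4))  // 100
    // No estimate yet: fall back to the configured cap (0 = unlimited).
    println(effectiveRatePerPartition(None, 0L, 4))           // 0
  }
}
```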
--- .../spark/streaming/kafka/DirectKafkaInputDStream.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 8a177077775c6..1000094e93cb3 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -95,8 +95,13 @@ class DirectKafkaInputDStream[ val effectiveRateLimitPerPartition = estimatedRateLimit .filter(_ > 0) - .map(limit => Math.min(maxRateLimitPerPartition, (limit / numPartitions))) - .getOrElse(maxRateLimitPerPartition) + .map { limit => + if (maxRateLimitPerPartition > 0) { + Math.min(maxRateLimitPerPartition, (limit / numPartitions)) + } else { + limit / numPartitions + } + }.getOrElse(maxRateLimitPerPartition) if (effectiveRateLimitPerPartition > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 From f023aa2fcc1d1dbb82aee568be0a8f2457c309ae Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 24 Aug 2015 23:34:50 -0700 Subject: [PATCH 059/802] [SPARK-10137] [STREAMING] Avoid to restart receivers if scheduleReceivers returns balanced results This PR fixes the following cases for `ReceiverSchedulingPolicy`. 1) Assume there are 4 executors: host1, host2, host3, host4, and 5 receivers: r1, r2, r3, r4, r5. Then `ReceiverSchedulingPolicy.scheduleReceivers` will return (r1 -> host1, r2 -> host2, r3 -> host3, r4 -> host4, r5 -> host1). Let's assume r1 starts first on `host1` as `scheduleReceivers` suggested, and tries to register with ReceiverTracker. But the previous `ReceiverSchedulingPolicy.rescheduleReceiver` will return (host2, host3, host4) according to the current executor weights (host1 -> 1.0, host2 -> 0.5, host3 -> 0.5, host4 -> 0.5), so ReceiverTracker will reject `r1`. This is unexpected since r1 is starting exactly where `scheduleReceivers` suggested. This case can be fixed by ignoring the information of the receiver that is rescheduling in `receiverTrackingInfoMap`. 2) Assume there are 3 executors (host1, host2, host3), each with 3 cores, and 3 receivers: r1, r2, r3. Assume r1 is running on host1. Now r2 is restarting; the previous `ReceiverSchedulingPolicy.rescheduleReceiver` will always return (host1, host2, host3). So it's possible that r2 will be scheduled to host1 by TaskScheduler. r3 is similar. Then at last, it's possible that there are 3 receivers running on host1, while host2 and host3 are idle. This issue can be fixed by returning only the executors that have the minimum weight rather than returning at least 3 executors. Author: zsxwing Closes #8340 from zsxwing/fix-receiver-scheduling.
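A minimal, self-contained model of the selection rule described in case 2) may help: return the idle executors if there are any, otherwise return every executor tied for the minimum weight. The host names and the weight map below are invented for the example; the real policy computes weights from receiver tracking info, as shown in the diff that follows.
```
// Toy model of the rescheduling rule described above: prefer idle
// executors; otherwise return all executors tied for the minimum weight
// (instead of always returning three candidates).
object MinWeightSketch {
  def candidates(executors: Seq[String], weights: Map[String, Double]): Seq[String] = {
    val idle = executors.filterNot(weights.contains)
    if (idle.nonEmpty) {
      idle
    } else {
      val minWeight = executors.map(weights).min
      executors.filter(e => weights(e) == minWeight)
    }
  }

  def main(args: Array[String]): Unit = {
    val executors = Seq("host1", "host2", "host3", "host4", "host5")
    // host3 carries no weight, so only the idle host is returned.
    println(candidates(executors,
      Map("host1" -> 1.5, "host2" -> 0.5, "host4" -> 0.5, "host5" -> 0.5)))   // List(host3)
    // No idle hosts: everything tied for the minimum weight is returned.
    println(candidates(executors,
      Map("host1" -> 1.5, "host2" -> 0.5, "host3" -> 1.0, "host4" -> 0.5, "host5" -> 0.5)))
    // List(host2, host4, host5)
  }
}
```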
--- .../scheduler/ReceiverSchedulingPolicy.scala | 58 +++++++--- .../streaming/scheduler/ReceiverTracker.scala | 106 ++++++++++++------ .../ReceiverSchedulingPolicySuite.scala | 13 ++- 3 files changed, 120 insertions(+), 57 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index ef5b687b5831a..10b5a7f57a802 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -22,6 +22,36 @@ import scala.collection.mutable import org.apache.spark.streaming.receiver.Receiver +/** + * A class that tries to schedule receivers with evenly distributed. There are two phases for + * scheduling receivers. + * + * - The first phase is global scheduling when ReceiverTracker is starting and we need to schedule + * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. + * It will try to schedule receivers with evenly distributed. ReceiverTracker should update its + * receiverTrackingInfoMap according to the results of `scheduleReceivers`. + * `ReceiverTrackingInfo.scheduledExecutors` for each receiver will set to an executor list that + * contains the scheduled locations. Then when a receiver is starting, it will send a register + * request and `ReceiverTracker.registerReceiver` will be called. In + * `ReceiverTracker.registerReceiver`, if a receiver's scheduled executors is set, it should check + * if the location of this receiver is one of the scheduled executors, if not, the register will + * be rejected. + * - The second phase is local scheduling when a receiver is restarting. There are two cases of + * receiver restarting: + * - If a receiver is restarting because it's rejected due to the real location and the scheduled + * executors mismatching, in other words, it fails to start in one of the locations that + * `scheduleReceivers` suggested, `ReceiverTracker` should firstly choose the executors that are + * still alive in the list of scheduled executors, then use them to launch the receiver job. + * - If a receiver is restarting without a scheduled executors list, or the executors in the list + * are dead, `ReceiverTracker` should call `rescheduleReceiver`. If so, `ReceiverTracker` should + * not set `ReceiverTrackingInfo.scheduledExecutors` for this executor, instead, it should clear + * it. Then when this receiver is registering, we can know this is a local scheduling, and + * `ReceiverTrackingInfo` should call `rescheduleReceiver` again to check if the launching + * location is matching. + * + * In conclusion, we should make a global schedule, try to achieve that exactly as long as possible, + * otherwise do local scheduling. + */ private[streaming] class ReceiverSchedulingPolicy { /** @@ -102,8 +132,7 @@ private[streaming] class ReceiverSchedulingPolicy { /** * Return a list of candidate executors to run the receiver. If the list is empty, the caller can - * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require - * returning `preferredNumExecutors` executors if possible. + * run this receiver in arbitrary executor. * * This method tries to balance executors' load. Here is the approach to schedule executors * for a receiver. 
@@ -122,9 +151,8 @@ private[streaming] class ReceiverSchedulingPolicy { * If a receiver is scheduled to an executor but has not yet run, it contributes * `1.0 / #candidate_executors_of_this_receiver` to the executor's weight. * - * At last, if there are more than `preferredNumExecutors` idle executors (weight = 0), - * returns all idle executors. Otherwise, we only return `preferredNumExecutors` best options - * according to the weights. + * At last, if there are any idle executors (weight = 0), returns all idle executors. + * Otherwise, returns the executors that have the minimum weight. * * * @@ -134,8 +162,7 @@ private[streaming] class ReceiverSchedulingPolicy { receiverId: Int, preferredLocation: Option[String], receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], - executors: Seq[String], - preferredNumExecutors: Int = 3): Seq[String] = { + executors: Seq[String]): Seq[String] = { if (executors.isEmpty) { return Seq.empty } @@ -156,15 +183,18 @@ private[streaming] class ReceiverSchedulingPolicy { } }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor - val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq - if (idleExecutors.size >= preferredNumExecutors) { - // If there are more than `preferredNumExecutors` idle executors, return all of them + val idleExecutors = executors.toSet -- executorWeights.keys + if (idleExecutors.nonEmpty) { scheduledExecutors ++= idleExecutors } else { - // If there are less than `preferredNumExecutors` idle executors, return 3 best options - scheduledExecutors ++= idleExecutors - val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) - scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + // There is no idle executor. So select all executors that have the minimum weight. + val sortedExecutors = executorWeights.toSeq.sortBy(_._2) + if (sortedExecutors.nonEmpty) { + val minWeight = sortedExecutors(0)._2 + scheduledExecutors ++= sortedExecutors.takeWhile(_._2 == minWeight).map(_._1) + } else { + // This should not happen since "executors" is not empty + } } scheduledExecutors.toSeq } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 30d25a64e307a..3d532a675db02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -244,8 +244,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } if (isTrackerStopping || isTrackerStopped) { - false - } else if (!scheduleReceiver(streamId).contains(hostPort)) { + return false + } + + val scheduledExecutors = receiverTrackingInfos(streamId).scheduledExecutors + val accetableExecutors = if (scheduledExecutors.nonEmpty) { + // This receiver is registering and it's scheduled by + // ReceiverSchedulingPolicy.scheduleReceivers. So use "scheduledExecutors" to check it. + scheduledExecutors.get + } else { + // This receiver is scheduled by "ReceiverSchedulingPolicy.rescheduleReceiver", so calling + // "ReceiverSchedulingPolicy.rescheduleReceiver" again to check it. 
+ scheduleReceiver(streamId) + } + + if (!accetableExecutors.contains(hostPort)) { // Refuse it since it's scheduled to a wrong executor false } else { @@ -426,12 +439,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false startReceiver(receiver, executors) } case RestartReceiver(receiver) => - val scheduledExecutors = schedulingPolicy.rescheduleReceiver( - receiver.streamId, - receiver.preferredLocation, - receiverTrackingInfos, - getExecutors) - updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) + // Old scheduled executors minus the ones that are not active any more + val oldScheduledExecutors = getStoredScheduledExecutors(receiver.streamId) + val scheduledExecutors = if (oldScheduledExecutors.nonEmpty) { + // Try global scheduling again + oldScheduledExecutors + } else { + val oldReceiverInfo = receiverTrackingInfos(receiver.streamId) + // Clear "scheduledExecutors" to indicate we are going to do local scheduling + val newReceiverInfo = oldReceiverInfo.copy( + state = ReceiverState.INACTIVE, scheduledExecutors = None) + receiverTrackingInfos(receiver.streamId) = newReceiverInfo + schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + } + // Assume there is one receiver restarting at one time, so we don't need to update + // receiverTrackingInfos startReceiver(receiver, scheduledExecutors) case c: CleanupOldBlocks => receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) @@ -464,6 +490,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false context.reply(true) } + /** + * Return the stored scheduled executors that are still alive. + */ + private def getStoredScheduledExecutors(receiverId: Int): Seq[String] = { + if (receiverTrackingInfos.contains(receiverId)) { + val scheduledExecutors = receiverTrackingInfos(receiverId).scheduledExecutors + if (scheduledExecutors.nonEmpty) { + val executors = getExecutors.toSet + // Only return the alive executors + scheduledExecutors.get.filter(executors) + } else { + Nil + } + } else { + Nil + } + } + /** * Start a receiver along with its scheduled executors */ @@ -484,7 +528,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) + val startReceiverFunc: Iterator[Receiver[_]] => Unit = + (iterator: Iterator[Receiver[_]]) => { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } + } // Create the RDD using the scheduledExecutors to run the receiver in a Spark job val receiverRDD: RDD[Receiver[_]] = @@ -541,31 +601,3 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } - -/** - * Function to start the receiver on the worker node. Use a class instead of closure to avoid - * the serialization issue. 
- */ -private[streaming] class StartReceiverFunc( - checkpointDirOption: Option[String], - serializableHadoopConf: SerializableConfiguration) - extends (Iterator[Receiver[_]] => Unit) with Serializable { - - override def apply(iterator: Iterator[Receiver[_]]): Unit = { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - if (TaskContext.get().attemptNumber() == 0) { - val receiver = iterator.next() - assert(iterator.hasNext == false) - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } else { - // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. - } - } - -} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 0418d776ecc9a..b2a51d72bac2b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -39,7 +39,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(scheduledExecutors.toSet === Set("host1", "host2")) } - test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { + test("rescheduleReceiver: return all idle executors if there are any idle executors") { val executors = Seq("host1", "host2", "host3", "host4", "host5") // host3 is idle val receiverTrackingInfoMap = Map( @@ -49,16 +49,16 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) } - test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { + test("rescheduleReceiver: return all executors that have minimum weight if no idle executors") { val executors = Seq("host1", "host2", "host3", "host4", "host5") - // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 - // host4 and host5 are idle + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0, host4 = 0.5, host5 = 0.5 val receiverTrackingInfoMap = Map( 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), - 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) + 2 -> ReceiverTrackingInfo(2, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None), + 3 -> ReceiverTrackingInfo(4, ReceiverState.SCHEDULED, Some(Seq("host4", "host5")), None)) val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( - 3, None, receiverTrackingInfoMap, executors) + 4, None, receiverTrackingInfoMap, executors) assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) } @@ -127,4 +127,5 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { assert(executors.isEmpty) } } + } From df7041d02d3fd44b08a859f5d77bf6fb726895f0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 24 Aug 2015 23:38:32 -0700 Subject: [PATCH 060/802] [SPARK-10196] [SQL] Correctly saving decimals in internal rows to JSON. https://issues.apache.org/jira/browse/SPARK-10196 Author: Yin Huai Closes #8408 from yhuai/DecimalJsonSPARK-10196. 
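The root cause here is easy to miss: as the one-line change below suggests, decimal values read from Spark SQL's internal rows are carried by the internal Decimal wrapper rather than java.math.BigDecimal, so a match on the external type silently falls through. The sketch below reproduces that behavior with a stand-in wrapper; InternalDecimal is hypothetical and not Spark's actual Decimal class.
```
import java.math.BigDecimal

// Stand-in for Spark's internal decimal wrapper (not the real class).
final case class InternalDecimal(underlying: BigDecimal) {
  def toJavaBigDecimal: BigDecimal = underlying
}

object DecimalJsonSketch {
  // Before the fix: matching on java.math.BigDecimal never fires for
  // internal-row values, so the writer falls through to a failure case.
  def writeBefore(value: Any): String = value match {
    case v: BigDecimal => v.toPlainString
    case other         => s"<unsupported: ${other.getClass.getSimpleName}>"
  }

  // After the fix: match the internal wrapper and convert explicitly.
  def writeAfter(value: Any): String = value match {
    case v: InternalDecimal => v.toJavaBigDecimal.toPlainString
    case v: BigDecimal      => v.toPlainString
    case other              => s"<unsupported: ${other.getClass.getSimpleName}>"
  }

  def main(args: Array[String]): Unit = {
    val v = InternalDecimal(new BigDecimal("20000.99"))
    println(writeBefore(v)) // <unsupported: InternalDecimal>
    println(writeAfter(v))  // 20000.99
  }
}
```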
--- .../datasources/json/JacksonGenerator.scala | 2 +- .../sources/JsonHadoopFsRelationSuite.scala | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 99ac7730bd1c9..330ba907b2ef9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -95,7 +95,7 @@ private[sql] object JacksonGenerator { case (FloatType, v: Float) => gen.writeNumber(v) case (DoubleType, v: Double) => gen.writeNumber(v) case (LongType, v: Long) => gen.writeNumber(v) - case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v) + case (DecimalType(), v: Decimal) => gen.writeNumber(v.toJavaBigDecimal) case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index ed6d512ab36fe..8ca3a17085194 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.math.BigDecimal + import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -75,4 +77,29 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { ) } } + + test("SPARK-10196: save decimal type to JSON") { + withTempDir { file => + file.delete() + + val schema = + new StructType() + .add("decimal", DecimalType(7, 2)) + + val data = + Row(new BigDecimal("10.02")) :: + Row(new BigDecimal("20000.99")) :: + Row(new BigDecimal("10000")) :: Nil + val df = createDataFrame(sparkContext.parallelize(data), schema) + + // Write the data out. + df.write.format(dataSourceName).save(file.getCanonicalPath) + + // Read it back and check the result. + checkAnswer( + read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + df + ) + } + } } From bf03fe68d62f33dda70dff45c3bda1f57b032dfc Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 25 Aug 2015 14:58:42 +0800 Subject: [PATCH 061/802] [SPARK-10136] [SQL] A more robust fix for SPARK-10136 PR #8341 is a valid fix for SPARK-10136, but it didn't catch the real root cause. The real problem can be rather tricky to explain, and requires audiences to be pretty familiar with parquet-format spec, especially details of `LIST` backwards-compatibility rules. Let me have a try to give an explanation here. The structure of the problematic Parquet schema generated by parquet-avro is something like this: ``` message m { group f (LIST) { // Level 1 repeated group array (LIST) { // Level 2 repeated array; // Level 3 } } } ``` (The schema generated by parquet-thrift is structurally similar, just replace the `array` at level 2 with `f_tuple`, and the other one at level 3 with `f_tuple_tuple`.) This structure consists of two nested legacy 2-level `LIST`-like structures: 1. 
The repeated group type at level 2 is the element type of the outer array defined at level 1 This group should map to an `CatalystArrayConverter.ElementConverter` when building converters. 2. The repeated primitive type at level 3 is the element type of the inner array defined at level 2 This group should also map to an `CatalystArrayConverter.ElementConverter`. The root cause of SPARK-10136 is that, the group at level 2 isn't properly recognized as the element type of level 1. Thus, according to parquet-format spec, the repeated primitive at level 3 is left as a so called "unannotated repeated primitive type", and is recognized as a required list of required primitive type, thus a `RepeatedPrimitiveConverter` instead of a `CatalystArrayConverter.ElementConverter` is created for it. According to parquet-format spec, unannotated repeated type shouldn't appear in a `LIST`- or `MAP`-annotated group. PR #8341 fixed this issue by allowing such unannotated repeated type appear in `LIST`-annotated groups, which is a non-standard, hacky, but valid fix. (I didn't realize this when authoring #8341 though.) As for the reason why level 2 isn't recognized as a list element type, it's because of the following `LIST` backwards-compatibility rule defined in the parquet-format spec: > If the repeated field is a group with one field and is named either `array` or uses the `LIST`-annotated group's name with `_tuple` appended then the repeated type is the element type and elements are required. (The `array` part is for parquet-avro compatibility, while the `_tuple` part is for parquet-thrift.) This rule is implemented in [`CatalystSchemaConverter.isElementType`] [1], but neglected in [`CatalystRowConverter.isElementType`] [2]. This PR delivers a more robust fix by adding this rule in the latter method. Note that parquet-avro 1.7.0 also suffers from this issue. Details can be found at [PARQUET-364] [3]. [1]: https://github.com/apache/spark/blob/85f9a61357994da5023b08b0a8a2eb09388ce7f8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala#L259-L305 [2]: https://github.com/apache/spark/blob/85f9a61357994da5023b08b0a8a2eb09388ce7f8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala#L456-L463 [3]: https://issues.apache.org/jira/browse/PARQUET-364 Author: Cheng Lian Closes #8361 from liancheng/spark-10136/proper-version. 
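To make the backwards-compatibility rule quoted above concrete, here is a small self-contained sketch using toy stand-ins for Parquet's type objects (these are not the real Type/GroupType API): a single-field repeated group counts as the element type itself when it is named `array` or `<parentName>_tuple`; otherwise it is treated as the standard wrapper and the element lives one level deeper. This mirrors the check the patch below adds to the row converter so that it agrees with the schema converter.
```
// Toy model of the LIST backwards-compatibility rule quoted above.
// (Simplified stand-ins, not parquet's actual Type/GroupType classes.)
sealed trait ToyType { def name: String }
case class ToyPrimitive(name: String) extends ToyType
case class ToyGroup(name: String, fields: ToyType*) extends ToyType

object ListCompatSketch {
  // Is the repeated field itself the element type, or just a wrapper
  // around the real element?
  def repeatedIsElementType(repeated: ToyType, parentName: String): Boolean =
    repeated match {
      case _: ToyPrimitive                              => true
      case g: ToyGroup if g.fields.size > 1             => true
      case g: ToyGroup if g.name == "array"             => true
      case g: ToyGroup if g.name == parentName + "_tuple" => true
      case _                                            => false
    }

  def main(args: Array[String]): Unit = {
    // parquet-avro style: repeated group named "array" is the element type.
    println(repeatedIsElementType(ToyGroup("array", ToyPrimitive("element")), "f"))   // true
    // parquet-thrift style: repeated group named "<parent>_tuple".
    println(repeatedIsElementType(ToyGroup("f_tuple", ToyPrimitive("element")), "f")) // true
    // Standard 3-level layout: a single-field group with any other name is a
    // wrapper, so the element is one level deeper.
    println(repeatedIsElementType(ToyGroup("list", ToyPrimitive("element")), "f"))    // false
  }
}
```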
--- .../parquet/CatalystRowConverter.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index d2c2db51769ba..cbf0704c4a9a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -415,8 +415,9 @@ private[parquet] class CatalystRowConverter( private val elementConverter: Converter = { val repeatedType = parquetSchema.getType(0) val elementType = catalystSchema.elementType + val parentName = parquetSchema.getName - if (isElementType(repeatedType, elementType)) { + if (isElementType(repeatedType, elementType, parentName)) { newConverter(repeatedType, elementType, new ParentContainerUpdater { override def set(value: Any): Unit = currentArray += value }) @@ -453,10 +454,13 @@ private[parquet] class CatalystRowConverter( * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules */ // scalastyle:on - private def isElementType(parquetRepeatedType: Type, catalystElementType: DataType): Boolean = { + private def isElementType( + parquetRepeatedType: Type, catalystElementType: DataType, parentName: String): Boolean = { (parquetRepeatedType, catalystElementType) match { case (t: PrimitiveType, _) => true case (t: GroupType, _) if t.getFieldCount > 1 => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == "array" => true + case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == parentName + "_tuple" => true case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true case _ => false } @@ -474,15 +478,9 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = converter - override def end(): Unit = { - converter.updater.end() - currentArray += currentElement - } + override def end(): Unit = currentArray += currentElement - override def start(): Unit = { - converter.updater.start() - currentElement = null - } + override def start(): Unit = currentElement = null } } From 82268f07abfa658869df2354ae72f8d6ddd119e8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 25 Aug 2015 00:04:10 -0700 Subject: [PATCH 062/802] [SPARK-9293] [SPARK-9813] Analysis should check that set operations are only performed on tables with equal numbers of columns This patch adds an analyzer rule to ensure that set operations (union, intersect, and except) are only applied to tables with the same number of columns. Without this rule, there are scenarios where invalid queries can return incorrect results instead of failing with error messages; SPARK-9813 provides one example of this problem. In other cases, the invalid query can crash at runtime with extremely confusing exceptions. I also performed a bit of cleanup to refactor some of those logical operators' code into a common `SetOperation` base class. Author: Josh Rosen Closes #7631 from JoshRosen/SPARK-9293. 
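The shape of the new rule is easy to see in isolation. The sketch below is a toy model rather than Spark code: `Plan` stands in for a resolved `LogicalPlan`, and the exception stands in for `failAnalysis`, but the message mirrors the one added to `CheckAnalysis` in the diff below.

```scala
// Toy model of the added check; Plan and checkSetOperation are illustrative stand-ins.
final case class Plan(name: String, output: Seq[String])

def checkSetOperation(nodeName: String, left: Plan, right: Plan): Unit =
  if (left.output.length != right.output.length) {
    throw new IllegalArgumentException(
      s"$nodeName can only be performed on tables with the same number of columns, " +
        s"but the left table has ${left.output.length} columns and the right has " +
        s"${right.output.length}")
  }

// Same arity on both sides: passes.
checkSetOperation("Union", Plan("t1", Seq("a", "b")), Plan("t2", Seq("c", "d")))

// Mismatched arity: rejected during analysis with a descriptive message, instead of
// silently producing wrong results or crashing at runtime.
try checkSetOperation("Intersect", Plan("t1", Seq("a", "b")), Plan("t3", Seq("c")))
catch { case e: IllegalArgumentException => println(e.getMessage) }
```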
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +++ .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++---- .../plans/logical/basicOperators.scala | 38 +++++++++---------- .../analysis/AnalysisErrorSuite.scala | 18 +++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- 6 files changed, 48 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 39f554c137c98..7701fd0451041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -137,6 +137,12 @@ trait CheckAnalysis { } } + case s @ SetOperation(left, right) if left.output.length != right.output.length => + failAnalysis( + s"${s.nodeName} can only be performed on tables with the same number of columns, " + + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2cb067f4aac91..a1aa2a2b2c680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -203,6 +203,7 @@ object HiveTypeCoercion { planName: String, left: LogicalPlan, right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + require(left.output.length == right.output.length) val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => @@ -229,15 +230,10 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p if p.analyzed => p - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) - Union(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) - Except(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) - Intersect(newLeft, newRight) + case s @ SetOperation(left, right) if s.childrenResolved + && left.output.length == right.output.length && !s.resolved => + val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) + s.makeCopy(Array(newLeft, newRight)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 73b8261260acb..722f69cdca827 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -89,13 +89,21 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) 
extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output: Seq[Attribute] = left.output + final override def output: Seq[Attribute] = left.output - override lazy val resolved: Boolean = + final override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } +} + +private[sql] object SetOperation { + def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) +} + +case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -103,6 +111,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { } } +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + case class Join( left: LogicalPlan, right: LogicalPlan, @@ -142,15 +154,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } - -case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} - case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], @@ -160,7 +163,7 @@ case class InsertIntoTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { @@ -440,10 +443,3 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7065adce04bf8..fbdd3a7776f50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -145,6 +145,24 @@ class AnalysisErrorSuite extends AnalysisTest { UnresolvedTestPlan(), "unresolved" :: Nil) + errorTest( + "union with unequal number of columns", + testRelation.unionAll(testRelation2), + "union" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "intersect with unequal number of columns", + testRelation.intersect(testRelation2), + "intersect" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "except with unequal 
number of columns", + testRelation.except(testRelation2), + "except" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + errorTest( "SPARK-9955: correct error message for aggregate", // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index bbe8c1911bf86..98d21aa76d64e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -751,7 +751,7 @@ private[hive] case class InsertIntoHiveTable( extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil - override def output: Seq[Attribute] = child.output + override def output: Seq[Attribute] = Seq.empty val numDynamicPartitions = partition.values.count(_.isEmpty) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 12c667e6e92da..62efda613a176 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -61,7 +61,7 @@ case class InsertIntoHiveTable( serializer } - def output: Seq[Attribute] = child.output + def output: Seq[Attribute] = Seq.empty def saveAsHiveFile( rdd: RDD[InternalRow], From d4549fe58fa0d781e0e891bceff893420cb1d598 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 25 Aug 2015 00:28:51 -0700 Subject: [PATCH 063/802] [SPARK-10214] [SPARKR] [DOCS] Improve SparkR Column, DataFrame API docs cc: shivaram ## Summary - Add name tags to each methods in DataFrame.R and column.R - Replace `rdname column` with `rdname {each_func}`. i.e. alias method : `rdname column` => `rdname alias` ## Generated PDF File https://drive.google.com/file/d/0B9biIZIU47lLNHN2aFpnQXlSeGs/view?usp=sharing ## JIRA [[SPARK-10214] Improve SparkR Column, DataFrame API docs - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-10214) Author: Yu ISHIKAWA Closes #8414 from yu-iskw/SPARK-10214. --- R/pkg/R/DataFrame.R | 101 +++++++++++++++++++++++++++++++++++--------- R/pkg/R/column.R | 40 ++++++++++++------ R/pkg/R/generics.R | 2 +- 3 files changed, 109 insertions(+), 34 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 895603235011e..10f3c4ea59864 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -27,9 +27,10 @@ setOldClass("jobj") #' \code{jsonFile}, \code{table} etc. 
#' @rdname DataFrame #' @seealso jsonFile, table +#' @docType class #' -#' @param env An R environment that stores bookkeeping states of the DataFrame -#' @param sdf A Java object reference to the backing Scala DataFrame +#' @slot env An R environment that stores bookkeeping states of the DataFrame +#' @slot sdf A Java object reference to the backing Scala DataFrame #' @export setClass("DataFrame", slots = list(env = "environment", @@ -61,6 +62,7 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @param x A SparkSQL DataFrame #' #' @rdname printSchema +#' @name printSchema #' @export #' @examples #'\dontrun{ @@ -84,6 +86,7 @@ setMethod("printSchema", #' @param x A SparkSQL DataFrame #' #' @rdname schema +#' @name schema #' @export #' @examples #'\dontrun{ @@ -106,6 +109,7 @@ setMethod("schema", #' @param x A SparkSQL DataFrame #' @param extended Logical. If extended is False, explain() only prints the physical plan. #' @rdname explain +#' @name explain #' @export #' @examples #'\dontrun{ @@ -135,6 +139,7 @@ setMethod("explain", #' @param x A SparkSQL DataFrame #' #' @rdname isLocal +#' @name isLocal #' @export #' @examples #'\dontrun{ @@ -158,6 +163,7 @@ setMethod("isLocal", #' @param numRows The number of rows to print. Defaults to 20. #' #' @rdname showDF +#' @name showDF #' @export #' @examples #'\dontrun{ @@ -181,6 +187,7 @@ setMethod("showDF", #' @param x A SparkSQL DataFrame #' #' @rdname show +#' @name show #' @export #' @examples #'\dontrun{ @@ -206,6 +213,7 @@ setMethod("show", "DataFrame", #' @param x A SparkSQL DataFrame #' #' @rdname dtypes +#' @name dtypes #' @export #' @examples #'\dontrun{ @@ -230,6 +238,8 @@ setMethod("dtypes", #' @param x A SparkSQL DataFrame #' #' @rdname columns +#' @name columns +#' @aliases names #' @export #' @examples #'\dontrun{ @@ -248,7 +258,7 @@ setMethod("columns", }) #' @rdname columns -#' @aliases names,DataFrame,function-method +#' @name names setMethod("names", signature(x = "DataFrame"), function(x) { @@ -256,6 +266,7 @@ setMethod("names", }) #' @rdname columns +#' @name names<- setMethod("names<-", signature(x = "DataFrame"), function(x, value) { @@ -273,6 +284,7 @@ setMethod("names<-", #' @param tableName A character vector containing the name of the table #' #' @rdname registerTempTable +#' @name registerTempTable #' @export #' @examples #'\dontrun{ @@ -299,6 +311,7 @@ setMethod("registerTempTable", #' the existing rows in the table. #' #' @rdname insertInto +#' @name insertInto #' @export #' @examples #'\dontrun{ @@ -321,7 +334,8 @@ setMethod("insertInto", #' #' @param x A SparkSQL DataFrame #' -#' @rdname cache-methods +#' @rdname cache +#' @name cache #' @export #' @examples #'\dontrun{ @@ -347,6 +361,7 @@ setMethod("cache", #' #' @param x The DataFrame to persist #' @rdname persist +#' @name persist #' @export #' @examples #'\dontrun{ @@ -372,6 +387,7 @@ setMethod("persist", #' @param x The DataFrame to unpersist #' @param blocking Whether to block until all blocks are deleted #' @rdname unpersist-methods +#' @name unpersist #' @export #' @examples #'\dontrun{ @@ -397,6 +413,7 @@ setMethod("unpersist", #' @param x A SparkSQL DataFrame #' @param numPartitions The number of partitions to use. 
#' @rdname repartition +#' @name repartition #' @export #' @examples #'\dontrun{ @@ -446,6 +463,7 @@ setMethod("toJSON", #' @param x A SparkSQL DataFrame #' @param path The directory where the file is saved #' @rdname saveAsParquetFile +#' @name saveAsParquetFile #' @export #' @examples #'\dontrun{ @@ -467,6 +485,7 @@ setMethod("saveAsParquetFile", #' #' @param x A SparkSQL DataFrame #' @rdname distinct +#' @name distinct #' @export #' @examples #'\dontrun{ @@ -488,7 +507,8 @@ setMethod("distinct", #' @description Returns a new DataFrame containing distinct rows in this DataFrame #' #' @rdname unique -#' @aliases unique +#' @name unique +#' @aliases distinct setMethod("unique", signature(x = "DataFrame"), function(x) { @@ -526,7 +546,7 @@ setMethod("sample", }) #' @rdname sample -#' @aliases sample +#' @name sample_frac setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), @@ -541,6 +561,8 @@ setMethod("sample_frac", #' @param x A SparkSQL DataFrame #' #' @rdname count +#' @name count +#' @aliases nrow #' @export #' @examples #'\dontrun{ @@ -574,6 +596,7 @@ setMethod("nrow", #' @param x a SparkSQL DataFrame #' #' @rdname ncol +#' @name ncol #' @export #' @examples #'\dontrun{ @@ -593,6 +616,7 @@ setMethod("ncol", #' @param x a SparkSQL DataFrame #' #' @rdname dim +#' @name dim #' @export #' @examples #'\dontrun{ @@ -613,8 +637,8 @@ setMethod("dim", #' @param x A SparkSQL DataFrame #' @param stringsAsFactors (Optional) A logical indicating whether or not string columns #' should be converted to factors. FALSE by default. - -#' @rdname collect-methods +#' @rdname collect +#' @name collect #' @export #' @examples #'\dontrun{ @@ -650,6 +674,7 @@ setMethod("collect", #' @return A new DataFrame containing the number of rows specified. #' #' @rdname limit +#' @name limit #' @export #' @examples #' \dontrun{ @@ -669,6 +694,7 @@ setMethod("limit", #' Take the first NUM rows of a DataFrame and return a the results as a data.frame #' #' @rdname take +#' @name take #' @export #' @examples #'\dontrun{ @@ -696,6 +722,7 @@ setMethod("take", #' @return A data.frame #' #' @rdname head +#' @name head #' @export #' @examples #'\dontrun{ @@ -717,6 +744,7 @@ setMethod("head", #' @param x A SparkSQL DataFrame #' #' @rdname first +#' @name first #' @export #' @examples #'\dontrun{ @@ -732,7 +760,7 @@ setMethod("first", take(x, 1) }) -# toRDD() +# toRDD # # Converts a Spark DataFrame to an RDD while preserving column names. # @@ -769,6 +797,7 @@ setMethod("toRDD", #' @seealso GroupedData #' @aliases group_by #' @rdname groupBy +#' @name groupBy #' @export #' @examples #' \dontrun{ @@ -792,7 +821,7 @@ setMethod("groupBy", }) #' @rdname groupBy -#' @aliases group_by +#' @name group_by setMethod("group_by", signature(x = "DataFrame"), function(x, ...) { @@ -804,7 +833,8 @@ setMethod("group_by", #' Compute aggregates by specifying a list of columns #' #' @param x a DataFrame -#' @rdname DataFrame +#' @rdname agg +#' @name agg #' @aliases summarize #' @export setMethod("agg", @@ -813,8 +843,8 @@ setMethod("agg", agg(groupBy(x), ...) }) -#' @rdname DataFrame -#' @aliases agg +#' @rdname agg +#' @name summarize setMethod("summarize", signature(x = "DataFrame"), function(x, ...) 
{ @@ -890,12 +920,14 @@ getColumn <- function(x, c) { } #' @rdname select +#' @name $ setMethod("$", signature(x = "DataFrame"), function(x, name) { getColumn(x, name) }) #' @rdname select +#' @name $<- setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) @@ -923,6 +955,7 @@ setMethod("$<-", signature(x = "DataFrame"), }) #' @rdname select +#' @name [[ setMethod("[[", signature(x = "DataFrame"), function(x, i) { if (is.numeric(i)) { @@ -933,6 +966,7 @@ setMethod("[[", signature(x = "DataFrame"), }) #' @rdname select +#' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) { if (is.numeric(j)) { @@ -1008,6 +1042,7 @@ setMethod("select", #' @param ... Additional expressions #' @return A DataFrame #' @rdname selectExpr +#' @name selectExpr #' @export #' @examples #'\dontrun{ @@ -1034,6 +1069,8 @@ setMethod("selectExpr", #' @param col A Column expression. #' @return A DataFrame with the new column added. #' @rdname withColumn +#' @name withColumn +#' @aliases mutate #' @export #' @examples #'\dontrun{ @@ -1057,7 +1094,7 @@ setMethod("withColumn", #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @rdname withColumn -#' @aliases withColumn +#' @name mutate #' @export #' @examples #'\dontrun{ @@ -1094,6 +1131,7 @@ setMethod("mutate", #' @param newCol The new column name. #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name withColumnRenamed #' @export #' @examples #'\dontrun{ @@ -1124,6 +1162,7 @@ setMethod("withColumnRenamed", #' @param newCol A named pair of the form new_column_name = existing_column #' @return A DataFrame with the column name changed. #' @rdname withColumnRenamed +#' @name rename #' @aliases withColumnRenamed #' @export #' @examples @@ -1165,6 +1204,8 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @param ... Additional sorting fields #' @return A DataFrame where all elements are sorted. #' @rdname arrange +#' @name arrange +#' @aliases orderby #' @export #' @examples #'\dontrun{ @@ -1191,7 +1232,7 @@ setMethod("arrange", }) #' @rdname arrange -#' @aliases orderBy,DataFrame,function-method +#' @name orderby setMethod("orderBy", signature(x = "DataFrame", col = "characterOrColumn"), function(x, col) { @@ -1207,6 +1248,7 @@ setMethod("orderBy", #' or a string containing a SQL statement #' @return A DataFrame containing only the rows that meet the condition. #' @rdname filter +#' @name filter #' @export #' @examples #'\dontrun{ @@ -1228,7 +1270,7 @@ setMethod("filter", }) #' @rdname filter -#' @aliases where,DataFrame,function-method +#' @name where setMethod("where", signature(x = "DataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1247,6 +1289,7 @@ setMethod("where", #' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner". #' @return A DataFrame containing the result of the join operation. #' @rdname join +#' @name join #' @export #' @examples #'\dontrun{ @@ -1279,8 +1322,9 @@ setMethod("join", dataFrame(sdf) }) -#' rdname merge -#' aliases join +#' @rdname merge +#' @name merge +#' @aliases join setMethod("merge", signature(x = "DataFrame", y = "DataFrame"), function(x, y, joinExpr = NULL, joinType = NULL, ...) { @@ -1298,6 +1342,7 @@ setMethod("merge", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the union. 
#' @rdname unionAll +#' @name unionAll #' @export #' @examples #'\dontrun{ @@ -1319,6 +1364,7 @@ setMethod("unionAll", #' @description Returns a new DataFrame containing rows of all parameters. # #' @rdname rbind +#' @name rbind #' @aliases unionAll setMethod("rbind", signature(... = "DataFrame"), @@ -1339,6 +1385,7 @@ setMethod("rbind", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the intersect. #' @rdname intersect +#' @name intersect #' @export #' @examples #'\dontrun{ @@ -1364,6 +1411,7 @@ setMethod("intersect", #' @param y A Spark DataFrame #' @return A DataFrame containing the result of the except operation. #' @rdname except +#' @name except #' @export #' @examples #'\dontrun{ @@ -1403,6 +1451,8 @@ setMethod("except", #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' #' @rdname write.df +#' @name write.df +#' @aliases saveDF #' @export #' @examples #'\dontrun{ @@ -1435,7 +1485,7 @@ setMethod("write.df", }) #' @rdname write.df -#' @aliases saveDF +#' @name saveDF #' @export setMethod("saveDF", signature(df = "DataFrame", path = "character"), @@ -1466,6 +1516,7 @@ setMethod("saveDF", #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' #' @rdname saveAsTable +#' @name saveAsTable #' @export #' @examples #'\dontrun{ @@ -1505,6 +1556,8 @@ setMethod("saveAsTable", #' @param ... Additional expressions #' @return A DataFrame #' @rdname describe +#' @name describe +#' @aliases summary #' @export #' @examples #'\dontrun{ @@ -1525,6 +1578,7 @@ setMethod("describe", }) #' @rdname describe +#' @name describe setMethod("describe", signature(x = "DataFrame"), function(x) { @@ -1538,7 +1592,7 @@ setMethod("describe", #' @description Computes statistics for numeric columns of the DataFrame #' #' @rdname summary -#' @aliases describe +#' @name summary setMethod("summary", signature(x = "DataFrame"), function(x) { @@ -1562,6 +1616,8 @@ setMethod("summary", #' @return A DataFrame #' #' @rdname nafunctions +#' @name dropna +#' @aliases na.omit #' @export #' @examples #'\dontrun{ @@ -1588,7 +1644,8 @@ setMethod("dropna", dataFrame(sdf) }) -#' @aliases dropna +#' @rdname nafunctions +#' @name na.omit #' @export setMethod("na.omit", signature(x = "DataFrame"), @@ -1615,6 +1672,7 @@ setMethod("na.omit", #' @return A DataFrame #' #' @rdname nafunctions +#' @name fillna #' @export #' @examples #'\dontrun{ @@ -1685,6 +1743,7 @@ setMethod("fillna", #' occurrences will have zero as their counts. 
#' #' @rdname statfunctions +#' @name crosstab #' @export #' @examples #' \dontrun{ diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a1f50c383367c..4805096f3f9c5 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -24,10 +24,9 @@ setOldClass("jobj") #' @title S4 class that represents a DataFrame column #' @description The column class supports unary, binary operations on DataFrame columns - #' @rdname column #' -#' @param jc reference to JVM DataFrame column +#' @slot jc reference to JVM DataFrame column #' @export setClass("Column", slots = list(jc = "jobj")) @@ -46,6 +45,7 @@ col <- function(x) { } #' @rdname show +#' @name show setMethod("show", "Column", function(object) { cat("Column", callJMethod(object@jc, "toString"), "\n") @@ -122,8 +122,11 @@ createMethods() #' alias #' #' Set a new name for a column - -#' @rdname column +#' +#' @rdname alias +#' @name alias +#' @family colum_func +#' @export setMethod("alias", signature(object = "Column"), function(object, data) { @@ -138,7 +141,9 @@ setMethod("alias", #' #' An expression that returns a substring. #' -#' @rdname column +#' @rdname substr +#' @name substr +#' @family colum_func #' #' @param start starting position #' @param stop ending position @@ -152,7 +157,9 @@ setMethod("substr", signature(x = "Column"), #' #' Test if the column is between the lower bound and upper bound, inclusive. #' -#' @rdname column +#' @rdname between +#' @name between +#' @family colum_func #' #' @param bounds lower and upper bounds setMethod("between", signature(x = "Column"), @@ -167,7 +174,9 @@ setMethod("between", signature(x = "Column"), #' Casts the column to a different data type. #' -#' @rdname column +#' @rdname cast +#' @name cast +#' @family colum_func #' #' @examples \dontrun{ #' cast(df$age, "string") @@ -189,11 +198,15 @@ setMethod("cast", #' Match a column with given values. #' -#' @rdname column +#' @rdname match +#' @name %in% +#' @aliases %in% #' @return a matched values as a result of comparing with given values. -#' @examples \dontrun{ -#' filter(df, "age in (10, 30)") -#' where(df, df$age %in% c(10, 30)) +#' @export +#' @examples +#' \dontrun{ +#' filter(df, "age in (10, 30)") +#' where(df, df$age %in% c(10, 30)) #' } setMethod("%in%", signature(x = "Column"), @@ -208,7 +221,10 @@ setMethod("%in%", #' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. #' -#' @rdname column +#' @rdname otherwise +#' @name otherwise +#' @family colum_func +#' @export setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 610a8c31223cd..a829d46c1894c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -441,7 +441,7 @@ setGeneric("filter", function(x, condition) { standardGeneric("filter") }) #' @export setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) -#' @rdname DataFrame +#' @rdname groupBy #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) From 57b960bf3706728513f9e089455a533f0244312e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 25 Aug 2015 08:32:20 +0100 Subject: [PATCH 064/802] [SPARK-6196] [BUILD] Remove MapR profiles in favor of hadoop-provided Follow up to https://github.com/apache/spark/pull/7047 pwendell mentioned that MapR should use `hadoop-provided` now, and indeed the new build script does not produce `mapr3`/`mapr4` artifacts anymore. 
Hence the action seems to be to remove the profiles, which are now not used. CC trystanleftwich Author: Sean Owen Closes #8338 from srowen/SPARK-6196. --- pom.xml | 38 -------------------------------------- 1 file changed, 38 deletions(-) diff --git a/pom.xml b/pom.xml index d5945f2546d38..0716016523ee1 100644 --- a/pom.xml +++ b/pom.xml @@ -2386,44 +2386,6 @@ - - mapr3 - - 1.0.3-mapr-3.0.3 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - - mapr4 - - 2.4.1-mapr-1408 - 2.4.1-mapr-1408 - 0.98.4-mapr-1408 - 3.4.5-mapr-1406 - - - - org.apache.curator - curator-recipes - ${curator.version} - - - org.apache.zookeeper - zookeeper - - - - - org.apache.zookeeper - zookeeper - 3.4.5-mapr-1406 - - - - hive-thriftserver From 1fc37581a52530bac5d555dbf14927a5780c3b75 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 25 Aug 2015 00:35:51 -0700 Subject: [PATCH 065/802] [SPARK-10210] [STREAMING] Filter out non-existent blocks before creating BlockRDD When write ahead log is not enabled, a recovered streaming driver still tries to run jobs using pre-failure block ids, and fails as the block do not exists in-memory any more (and cannot be recovered as receiver WAL is not enabled). This occurs because the driver-side WAL of ReceivedBlockTracker is recovers that past block information, and ReceiveInputDStream creates BlockRDDs even if those blocks do not exist. The solution in this PR is to filter out block ids that do not exist before creating the BlockRDD. In addition, it adds unit tests to verify other logic in ReceiverInputDStream. Author: Tathagata Das Closes #8405 from tdas/SPARK-10210. --- .../dstream/ReceiverInputDStream.scala | 10 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 2 +- .../streaming/ReceiverInputDStreamSuite.scala | 156 ++++++++++++++++++ 3 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a15800917c6f4..6c139f32da31d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -116,7 +116,15 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont logWarning("Some blocks have Write Ahead Log information; this is unexpected") } } - new BlockRDD[T](ssc.sc, blockIds) + val validBlockIds = blockIds.filter { id => + ssc.sparkContext.env.blockManager.master.contains(id) + } + if (validBlockIds.size != blockIds.size) { + logWarning("Some blocks could not be recovered as they were not found in memory. 
" + + "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "for more details.") + } + new BlockRDD[T](ssc.sc, validBlockIds) } } else { // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 620b8a36a2baf..e081ffe46f502 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -75,7 +75,7 @@ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, @transient blockIds: Array[BlockId], - @transient walRecordHandles: Array[WriteAheadLogRecordHandle], + @transient val walRecordHandles: Array[WriteAheadLogRecordHandle], @transient isBlockIdValid: Array[Boolean] = Array.empty, storeInBlockManager: Boolean = false, storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala new file mode 100644 index 0000000000000..6d388d9624d92 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD +import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.{SparkConf, SparkEnv} + +class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { + + override def afterAll(): Unit = { + StreamingContext.getActive().map { _.stop() } + } + + testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithoutWAL("createBlockRDD creates correct BlockRDD with block info") { receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = false) } + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.forall(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + testWithoutWAL("createBlockRDD filters non-existent blocks before creating BlockRDD") { + receiverStream => + val presentBlockInfos = Seq.fill(2)(createBlockInfo(withWALInfo = false, createBlock = true)) + val absentBlockInfos = Seq.fill(3)(createBlockInfo(withWALInfo = false, createBlock = false)) + val blockInfos = presentBlockInfos ++ absentBlockInfos + val blockIds = blockInfos.map(_.blockId) + + // Verify that there are some blocks that are present, and some that are not + require(blockIds.exists(blockId => SparkEnv.get.blockManager.master.contains(blockId))) + require(blockIds.exists(blockId => !SparkEnv.get.blockManager.master.contains(blockId))) + + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === presentBlockInfos.map { _.blockId}) + } + + testWithWAL("createBlockRDD creates empty WALBackedBlockRDD when no block info") { + receiverStream => + val rdd = receiverStream.createBlockRDD(Time(0), Seq.empty) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + assert(rdd.isEmpty()) + } + + testWithWAL( + "createBlockRDD creates correct WALBackedBlockRDD with all block info having WAL info") { + receiverStream => + val blockInfos = Seq.fill(5) { createBlockInfo(withWALInfo = true) } + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[WriteAheadLogBackedBlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + assert(blockRDD.walRecordHandles.toSeq === blockInfos.map { _.walRecordHandleOption.get }) + } + + testWithWAL("createBlockRDD creates BlockRDD when some 
block info dont have WAL info") { + receiverStream => + val blockInfos1 = Seq.fill(2) { createBlockInfo(withWALInfo = true) } + val blockInfos2 = Seq.fill(3) { createBlockInfo(withWALInfo = false) } + val blockInfos = blockInfos1 ++ blockInfos2 + val blockIds = blockInfos.map(_.blockId) + val rdd = receiverStream.createBlockRDD(Time(0), blockInfos) + assert(rdd.isInstanceOf[BlockRDD[_]]) + val blockRDD = rdd.asInstanceOf[BlockRDD[_]] + assert(blockRDD.blockIds.toSeq === blockIds) + } + + + private def testWithoutWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"Without WAL enabled: $msg") { + runTest(enableWAL = false, body) + } + } + + private def testWithWAL(msg: String)(body: ReceiverInputDStream[_] => Unit): Unit = { + test(s"With WAL enabled: $msg") { + runTest(enableWAL = true, body) + } + } + + private def runTest(enableWAL: Boolean, body: ReceiverInputDStream[_] => Unit): Unit = { + val conf = new SparkConf() + conf.setMaster("local[4]").setAppName("ReceiverInputDStreamSuite") + conf.set(WriteAheadLogUtils.RECEIVER_WAL_ENABLE_CONF_KEY, enableWAL.toString) + require(WriteAheadLogUtils.enableReceiverLog(conf) === enableWAL) + val ssc = new StreamingContext(conf, Seconds(1)) + val receiverStream = new ReceiverInputDStream[Int](ssc) { + override def getReceiver(): Receiver[Int] = null + } + withStreamingContext(ssc) { ssc => + body(receiverStream) + } + } + + /** + * Create a block info for input to the ReceiverInputDStream.createBlockRDD + * @param withWALInfo Create block with WAL info in it + * @param createBlock Actually create the block in the BlockManager + * @return + */ + private def createBlockInfo( + withWALInfo: Boolean, + createBlock: Boolean = true): ReceivedBlockInfo = { + val blockId = new StreamBlockId(0, Random.nextLong()) + if (createBlock) { + SparkEnv.get.blockManager.putSingle(blockId, 1, StorageLevel.MEMORY_ONLY, tellMaster = true) + require(SparkEnv.get.blockManager.master.contains(blockId)) + } + val storeResult = if (withWALInfo) { + new WriteAheadLogBasedStoreResult(blockId, None, new WriteAheadLogRecordHandle { }) + } else { + new BlockManagerBasedStoreResult(blockId, None) + } + new ReceivedBlockInfo(0, None, None, storeResult) + } +} From 2f493f7e3924b769160a16f73cccbebf21973b91 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 16:00:44 +0800 Subject: [PATCH 066/802] [SPARK-10177] [SQL] fix reading Timestamp in parquet from Hive We misunderstood the Julian days and nanoseconds of the day in parquet (as TimestampType) from Hive/Impala, they are overlapped, so can't be added together directly. In order to avoid the confusing rounding when do the converting, we use `2440588` as the Julian Day of epoch of unix timestamp (which should be 2440587.5). Author: Davies Liu Author: Cheng Lian Closes #8400 from davies/timestamp_parquet. 
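To make the rounding change concrete, here is a self-contained sketch of the conversion arithmetic after this patch. It mirrors the updated `DateTimeUtils` methods shown in the diff below, but is illustrative rather than production code.

```scala
object JulianDaySketch {
  // 2440587.5 is the true Julian day of the unix epoch; the code rounds up to 2440588
  // so that day boundaries line up with what Hive/Impala write into Parquet.
  val JulianDayOfEpoch = 2440588L
  val SecondsPerDay    = 24L * 60 * 60
  val MicrosPerSecond  = 1000L * 1000L
  val NanosPerSecond   = MicrosPerSecond * 1000L

  /** Julian day plus nanoseconds-in-day -> microseconds since 1970-01-01 00:00:00 UTC. */
  def fromJulianDay(day: Int, nanos: Long): Long =
    (day - JulianDayOfEpoch) * SecondsPerDay * MicrosPerSecond + nanos / 1000L

  /** Microseconds since the epoch -> (Julian day, nanoseconds in that day). */
  def toJulianDay(us: Long): (Int, Long) = {
    val seconds      = us / MicrosPerSecond
    val day          = (seconds / SecondsPerDay + JulianDayOfEpoch).toInt
    val secondsInDay = seconds % SecondsPerDay
    (day, secondsInDay * NanosPerSecond + (us % MicrosPerSecond) * 1000L)
  }

  def main(args: Array[String]): Unit = {
    // The epoch itself now maps to (2440588, 0) instead of carrying a half-day offset,
    // so timestamps round-trip without a spurious 12-hour shift.
    assert(fromJulianDay(2440588, 0L) == 0L)
    val (d, ns) = toJulianDay(0L)
    assert(d == 2440588 && ns == 0L)
  }
}
```

The previously-ignored timestamp test in `ParquetHiveCompatibilitySuite`, re-enabled further down in this patch, exercises exactly this round trip against values written by Hive.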
--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 7 ++++--- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++---- .../sql/hive/ParquetHiveCompatibilitySuite.scala | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 672620460c3c5..d652fce3fd9b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -37,7 +37,8 @@ object DateTimeUtils { type SQLTimestamp = Long // see http://stackoverflow.com/questions/466321/convert-unix-timestamp-to-julian - final val JULIAN_DAY_OF_EPOCH = 2440587 // and .5 + // it's 2440587.5, rounding up to compatible with Hive + final val JULIAN_DAY_OF_EPOCH = 2440588 final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L @@ -183,7 +184,7 @@ object DateTimeUtils { */ def fromJulianDay(day: Int, nanoseconds: Long): SQLTimestamp = { // use Long to avoid rounding errors - val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY - SECONDS_PER_DAY / 2 + val seconds = (day - JULIAN_DAY_OF_EPOCH).toLong * SECONDS_PER_DAY seconds * MICROS_PER_SECOND + nanoseconds / 1000L } @@ -191,7 +192,7 @@ object DateTimeUtils { * Returns Julian day and nanoseconds in a day from the number of microseconds */ def toJulianDay(us: SQLTimestamp): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND + SECONDS_PER_DAY / 2 + val seconds = us / MICROS_PER_SECOND val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH val secondsInDay = seconds % SECONDS_PER_DAY val nanos = (us % MICROS_PER_SECOND) * 1000L diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index d18fa4df13355..1596bb79fa94b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -49,13 +49,18 @@ class DateTimeUtilsSuite extends SparkFunSuite { test("us and julian day") { val (d, ns) = toJulianDay(0) assert(d === JULIAN_DAY_OF_EPOCH) - assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) + val t = Timestamp.valueOf("2015-06-11 10:10:10.100") val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t2)) + val t1 = toJavaTimestamp(fromJulianDay(d1, ns1)) + assert(t.equals(t1)) + + val t2 = Timestamp.valueOf("2015-06-11 20:10:10.100") + val (d2, ns2) = toJulianDay(fromJavaTimestamp(t2)) + val t22 = toJavaTimestamp(fromJulianDay(d2, ns2)) + assert(t2.equals(t22)) } test("SPARK-6785: java date conversion before and after epoch") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index bc30180cf0917..91d7a48208e8d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ 
-113,7 +113,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with Before "BOOLEAN", "TINYINT", "SMALLINT", "INT", "BIGINT", "FLOAT", "DOUBLE", "STRING") } - ignore("SPARK-10177 timestamp") { + test("SPARK-10177 timestamp") { testParquetHiveCompatibility(Row(Timestamp.valueOf("2015-08-24 00:31:00")), "TIMESTAMP") } From 7bc9a8c6249300ded31ea931c463d0a8f798e193 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 25 Aug 2015 01:06:36 -0700 Subject: [PATCH 067/802] [SPARK-10195] [SQL] Data sources Filter should not expose internal types Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third-parties. This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0. To avoid these issues in the future we should only expose public types through these Filter objects. This patch accomplishes this by using CatalystTypeConverters to add the appropriate conversions. Author: Josh Rosen Closes #8403 from JoshRosen/datasources-internal-vs-external-types. --- .../datasources/DataSourceStrategy.scala | 67 ++++++++++--------- .../execution/datasources/jdbc/JDBCRDD.scala | 2 +- .../datasources/parquet/ParquetFilters.scala | 19 +++--- .../spark/sql/sources/FilteredScanSuite.scala | 7 ++ 4 files changed, 54 insertions(+), 41 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2a4c40db8bb66..6c1ef6a6df887 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} -import org.apache.spark.sql.catalyst.{InternalRow, expressions} +import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical @@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { */ protected[sql] def selectFilters(filters: Seq[Expression]) = { def translate(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => - Some(sources.EqualTo(a.name, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) => - Some(sources.EqualTo(a.name, v)) - - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) => - Some(sources.EqualNullSafe(a.name, v)) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) => - Some(sources.EqualNullSafe(a.name, v)) - - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThan(a.name, v)) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => - Some(sources.LessThan(a.name, v)) - - case expressions.LessThan(a: Attribute, Literal(v, _)) => - Some(sources.LessThan(a.name, v)) - case expressions.LessThan(Literal(v, _), a: Attribute) => - Some(sources.GreaterThan(a.name, v)) - - 
case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.GreaterThanOrEqual(a.name, v)) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, v)) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => - Some(sources.LessThanOrEqual(a.name, v)) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, v)) + case expressions.EqualTo(a: Attribute, Literal(v, t)) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), a: Attribute) => + Some(sources.EqualTo(a.name, convertToScala(v, t))) + + case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => + Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) + + case expressions.GreaterThan(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), a: Attribute) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + + case expressions.LessThan(a: Attribute, Literal(v, t)) => + Some(sources.LessThan(a.name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), a: Attribute) => + Some(sources.GreaterThan(a.name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => + Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => + Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) case expressions.InSet(a: Attribute, set) => - Some(sources.In(a.name, set.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, set.toArray.map(toScala))) // Because we only convert In to InSet in Optimizer when there are more than certain // items. So it is possible we still get an In expression here that needs to be pushed // down. case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) => val hSet = list.map(e => e.eval(EmptyRow)) - Some(sources.In(a.name, hSet.toArray)) + val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) + Some(sources.In(a.name, hSet.toArray.map(toScala))) case expressions.IsNull(a: Attribute) => Some(sources.IsNull(a.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e537d631f4559..730d88b024cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -262,7 +262,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. 
*/ private def compileValue(value: Any): Any = value match { - case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" + case stringValue: String => s"'${escapeSql(stringValue)}'" case _ => value } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index c74c8388632f5..c6b3fe7900da8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -65,7 +64,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -86,7 +85,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -104,7 +103,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.lt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -121,7 +121,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.ltEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -138,7 +139,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + FilterApi.gt(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -155,7 +157,8 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + 
FilterApi.gtEq(binaryColumn(n), + Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8"))) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -177,7 +180,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), - SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes)))) + SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))))) case BinaryType => (n: String, v: Set[Any]) => FilterApi.userDefined(binaryColumn(n), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index c81c3d3982805..68ce37c00077e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.sources import scala.language.existentials import org.apache.spark.rdd.RDD +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL case StringStartsWith("c", v) => _.startsWith(v) case StringEndsWith("c", v) => _.endsWith(v) case StringContains("c", v) => _.contains(v) + case EqualTo("c", v: String) => _.equals(v) + case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters") + case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s) case _ => (c: String) => true } @@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1) testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1) + testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1) + def testPushDown(sqlString: String, expectedCount: Int): Unit = { test(s"PushDown Returns $expectedCount: $sqlString") { val queryExecution = sql(sqlString).queryExecution From 0e6368ffaec1965d0c7f89420e04a974675c7f6e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 25 Aug 2015 16:19:34 +0800 Subject: [PATCH 068/802] [SPARK-10197] [SQL] Add null check in wrapperFor (inside HiveInspectors). https://issues.apache.org/jira/browse/SPARK-10197 Author: Yin Huai Closes #8407 from yhuai/ORCSPARK-10197. 
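The pattern behind the fix is simply that every Catalyst-to-Hive wrapper must pass `null` through instead of dereferencing it. Below is a minimal, dependency-free sketch of that pattern; the `nullSafe` helper is hypothetical, since the patch itself inlines the checks, as shown in the diff below.

```scala
// Hypothetical helper for illustration only: lift a conversion into a null-tolerant wrapper.
def nullSafe[A](convert: A => Any): Any => Any = {
  case null => null
  case o    => convert(o.asInstanceOf[A])
}

// Example with a trivial conversion: nulls flow through instead of triggering an NPE,
// which is what the added checks achieve for the varchar/decimal/date/timestamp wrappers.
val intWrapper: Any => Any = nullSafe[Int](i => java.lang.Integer.valueOf(i))
assert(intWrapper(null) == null)
assert(intWrapper(3) == 3)
```

The new "write null values" test added to `OrcSourceSuite` in this patch covers all of the affected types in one pass.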
--- .../spark/sql/hive/HiveInspectors.scala | 29 +++++++++++++++---- .../spark/sql/hive/orc/OrcSourceSuite.scala | 29 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 9824dad239596..64fffdbf9b020 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -370,17 +370,36 @@ private[hive] trait HiveInspectors { protected def wrapperFor(oi: ObjectInspector, dataType: DataType): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => (o: Any) => - val s = o.asInstanceOf[UTF8String].toString - new HiveVarchar(s, s.size) + if (o != null) { + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) + } else { + null + } case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + (o: Any) => + if (o != null) { + HiveDecimal.create(o.asInstanceOf[Decimal].toJavaBigDecimal) + } else { + null + } case _: JavaDateObjectInspector => - (o: Any) => DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaDate(o.asInstanceOf[Int]) + } else { + null + } case _: JavaTimestampObjectInspector => - (o: Any) => DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + (o: Any) => + if (o != null) { + DateTimeUtils.toJavaTimestamp(o.asInstanceOf[Long]) + } else { + null + } case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 82e08caf46457..80c38084f293d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -121,6 +121,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM normal_orc_as_source"), (6 to 10).map(i => Row(i, s"part-$i"))) } + + test("write null values") { + sql("DROP TABLE IF EXISTS orcNullValues") + + val df = sql( + """ + |SELECT + | CAST(null as TINYINT), + | CAST(null as SMALLINT), + | CAST(null as INT), + | CAST(null as BIGINT), + | CAST(null as FLOAT), + | CAST(null as DOUBLE), + | CAST(null as DECIMAL(7,2)), + | CAST(null as TIMESTAMP), + | CAST(null as DATE), + | CAST(null as STRING), + | CAST(null as VARCHAR(10)) + |FROM orc_temp_table limit 1 + """.stripMargin) + + df.write.format("orc").saveAsTable("orcNullValues") + + checkAnswer( + sql("SELECT * FROM orcNullValues"), + Row.fromSeq(Seq.fill(11)(null))) + + sql("DROP TABLE IF EXISTS orcNullValues") + } } class OrcSourceSuite extends OrcSuite { From 5c14890159a5711072bf395f662b2433a389edf9 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 25 Aug 2015 11:48:55 +0100 Subject: [PATCH 069/802] [DOC] add missing parameters in SparkContext.scala for scala doc Author: Zhang, Liye Closes #8412 from liyezhang556520/minorDoc. 
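[Editor's note] For readers unfamiliar with the convention being filled in below, a small illustrative sketch of Scaladoc `@param` tags; `ScaladocExample` and `loadPaths` are made-up names and not part of SparkContext:

```scala
object ScaladocExample {
  /**
   * Illustrative only: shows the `@param`/`@return` style the patch adds.
   *
   * @param path Directory to the input data files; the path can be a comma
   *             separated list of inputs.
   * @param minPartitions A suggestion value of the minimal splitting number
   *                      for the input data.
   * @return the individual paths parsed out of `path`
   */
  def loadPaths(path: String, minPartitions: Int): Seq[String] = {
    require(minPartitions >= 1, "minPartitions must be positive")
    path.split(",").toSeq
  }

  def main(args: Array[String]): Unit =
    println(loadPaths("/data/a,/data/b", 2))
}
```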
--- .../scala/org/apache/spark/SparkContext.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1ddaca8a5ba8c..9849aff85d72e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -114,6 +114,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * :: DeveloperApi :: * Alternative constructor for setting preferred locations where Spark will create executors. * + * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] * from a list of input files or InputFormats for the application. @@ -145,6 +146,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. + * Can be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. */ def this( master: String, @@ -841,6 +845,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred, large file is also allowable, but may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -889,6 +896,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @note Small files are preferred; very large files may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental @@ -918,8 +928,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' We ensure that the byte array for each record in the resulting RDD * has the provided record length. * - * @param path Directory to the input data files + * @param path Directory to the input data files, the path can be comma separated paths as the + * list of inputs. * @param recordLength The length at which to split the records + * @param conf Configuration for setting up the dataset. + * * @return An RDD of data with values, represented as byte arrays */ @Experimental From 7f1e507bf7e82bff323c5dec3c1ee044687c4173 Mon Sep 17 00:00:00 2001 From: ehnalis Date: Tue, 25 Aug 2015 12:30:06 +0100 Subject: [PATCH 070/802] Fixed a typo in DAGScheduler. Author: ehnalis Closes #8308 from ehnalis/master. 
--- .../apache/spark/scheduler/DAGScheduler.scala | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 684db6646765f..daf9b0f95273e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -152,17 +152,24 @@ class DAGScheduler( // may lead to more delay in scheduling if those locations are busy. private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 - // Called by TaskScheduler to report task's starting. + /** + * Called by the TaskSetManager to report task's starting. + */ def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) } - // Called to report that a task has completed and results are being fetched remotely. + /** + * Called by the TaskSetManager to report that a task has completed + * and results are being fetched remotely. + */ def taskGettingResult(taskInfo: TaskInfo) { eventProcessLoop.post(GettingResultEvent(taskInfo)) } - // Called by TaskScheduler to report task completions or failures. + /** + * Called by the TaskSetManager to report task completions or failures. + */ def taskEnded( task: Task[_], reason: TaskEndReason, @@ -188,18 +195,24 @@ class DAGScheduler( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } - // Called by TaskScheduler when an executor fails. + /** + * Called by TaskScheduler implementation when an executor fails. + */ def executorLost(execId: String): Unit = { eventProcessLoop.post(ExecutorLost(execId)) } - // Called by TaskScheduler when a host is added + /** + * Called by TaskScheduler implementation when a host is added. + */ def executorAdded(execId: String, host: String): Unit = { eventProcessLoop.post(ExecutorAdded(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or - // cancellation of the job itself. + /** + * Called by the TaskSetManager to cancel an entire TaskSet due to either repeated failures or + * cancellation of the job itself. + */ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = { eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception)) } From 69c9c177160e32a2fbc9b36ecc52156077fca6fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 25 Aug 2015 12:33:13 +0100 Subject: [PATCH 071/802] [SPARK-9613] [CORE] Ban use of JavaConversions and migrate all existing uses to JavaConverters Replace `JavaConversions` implicits with `JavaConverters` Most occurrences I've seen so far are necessary conversions; a few have been avoidable. None are in critical code as far as I see, yet. Author: Sean Owen Closes #8033 from srowen/SPARK-9613. 
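[Editor's note] To make the migration below easier to follow, here is a minimal standalone sketch (not Spark code; `ConvertersDemo` is a made-up object) contrasting the two styles: `JavaConverters` requires an explicit `.asScala`/`.asJava` at each Java/Scala boundary, so conversions are visible at the call site instead of happening through hidden implicits.

```scala
import java.util.{ArrayList => JArrayList}
import scala.collection.JavaConverters._

object ConvertersDemo {
  def main(args: Array[String]): Unit = {
    val javaList = new JArrayList[String]()
    javaList.add("a")
    javaList.add("b")

    // Explicit, wrapper-based view of the Java list as a Scala collection
    // (no copying), where JavaConversions would have converted implicitly.
    val scalaSeq: Seq[String] = javaList.asScala
    println(scalaSeq.mkString(","))          // a,b

    // And back again when calling a Java API that expects java.util.List.
    val backToJava: java.util.List[String] = scalaSeq.asJava
    println(backToJava.size())               // 2
  }
}
```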
--- .../shuffle/unsafe/UnsafeShuffleWriter.java | 4 +- .../org/apache/spark/MapOutputTracker.scala | 4 +- .../scala/org/apache/spark/SSLOptions.scala | 11 +- .../scala/org/apache/spark/SparkContext.scala | 4 +- .../scala/org/apache/spark/TestUtils.scala | 9 +- .../apache/spark/api/java/JavaHadoopRDD.scala | 4 +- .../spark/api/java/JavaNewHadoopRDD.scala | 4 +- .../apache/spark/api/java/JavaPairRDD.scala | 19 ++- .../apache/spark/api/java/JavaRDDLike.scala | 75 +++++------- .../spark/api/java/JavaSparkContext.scala | 20 ++-- .../spark/api/python/PythonHadoopUtil.scala | 28 ++--- .../apache/spark/api/python/PythonRDD.scala | 26 ++--- .../apache/spark/api/python/PythonUtils.scala | 15 ++- .../api/python/PythonWorkerFactory.scala | 11 +- .../apache/spark/api/python/SerDeUtil.scala | 3 +- .../WriteInputFormatTestDataGenerator.scala | 8 +- .../scala/org/apache/spark/api/r/RRDD.scala | 13 ++- .../scala/org/apache/spark/api/r/RUtils.scala | 5 +- .../scala/org/apache/spark/api/r/SerDe.scala | 4 +- .../spark/broadcast/TorrentBroadcast.scala | 4 +- .../spark/deploy/ExternalShuffleService.scala | 8 +- .../apache/spark/deploy/PythonRunner.scala | 4 +- .../apache/spark/deploy/RPackageUtils.scala | 4 +- .../org/apache/spark/deploy/RRunner.scala | 4 +- .../spark/deploy/SparkCuratorUtil.scala | 4 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 19 +-- .../spark/deploy/SparkSubmitArguments.scala | 6 +- .../master/ZooKeeperPersistenceEngine.scala | 6 +- .../spark/deploy/worker/CommandUtils.scala | 5 +- .../spark/deploy/worker/DriverRunner.scala | 8 +- .../spark/deploy/worker/ExecutorRunner.scala | 7 +- .../apache/spark/deploy/worker/Worker.scala | 1 - .../org/apache/spark/executor/Executor.scala | 6 +- .../spark/executor/ExecutorSource.scala | 4 +- .../spark/executor/MesosExecutorBackend.scala | 6 +- .../spark/input/PortableDataStream.scala | 11 +- .../input/WholeTextFileInputFormat.scala | 8 +- .../spark/launcher/WorkerCommandBuilder.scala | 4 +- .../apache/spark/metrics/MetricsConfig.scala | 22 ++-- .../network/netty/NettyBlockRpcServer.scala | 4 +- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/network/nio/Connection.scala | 4 +- .../spark/partial/GroupedCountEvaluator.scala | 10 +- .../spark/partial/GroupedMeanEvaluator.scala | 10 +- .../spark/partial/GroupedSumEvaluator.scala | 10 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +- .../scala/org/apache/spark/rdd/PipedRDD.scala | 6 +- .../org/apache/spark/rdd/SubtractedRDD.scala | 4 +- .../spark/scheduler/InputFormatInfo.scala | 4 +- .../org/apache/spark/scheduler/Pool.scala | 10 +- .../mesos/CoarseMesosSchedulerBackend.scala | 20 ++-- .../mesos/MesosClusterPersistenceEngine.scala | 4 +- .../cluster/mesos/MesosClusterScheduler.scala | 14 +-- .../cluster/mesos/MesosSchedulerBackend.scala | 22 ++-- .../cluster/mesos/MesosSchedulerUtils.scala | 25 ++-- .../spark/serializer/KryoSerializer.scala | 10 +- .../shuffle/FileShuffleBlockResolver.scala | 8 +- .../storage/BlockManagerMasterEndpoint.scala | 8 +- .../org/apache/spark/util/AkkaUtils.scala | 4 +- .../org/apache/spark/util/ListenerBus.scala | 7 +- .../spark/util/MutableURLClassLoader.scala | 2 - .../spark/util/TimeStampedHashMap.scala | 10 +- .../spark/util/TimeStampedHashSet.scala | 4 +- .../scala/org/apache/spark/util/Utils.scala | 20 ++-- .../apache/spark/util/collection/Utils.scala | 4 +- .../java/org/apache/spark/JavaAPISuite.java | 6 +- .../org/apache/spark/SparkConfSuite.scala | 7 +- .../spark/deploy/LogUrlsStandaloneSuite.scala | 1 - .../spark/deploy/RPackageUtilsSuite.scala 
| 8 +- .../deploy/worker/ExecutorRunnerTest.scala | 5 +- .../spark/scheduler/SparkListenerSuite.scala | 9 +- .../mesos/MesosSchedulerBackendSuite.scala | 15 +-- .../serializer/KryoSerializerSuite.scala | 3 +- .../org/apache/spark/ui/UISeleniumSuite.scala | 21 ++-- .../spark/examples/CassandraCQLTest.scala | 15 +-- .../apache/spark/examples/CassandraTest.scala | 6 +- .../spark/examples/DriverSubmissionTest.scala | 6 +- .../pythonconverters/AvroConverters.scala | 16 +-- .../CassandraConverters.scala | 14 ++- .../pythonconverters/HBaseConverters.scala | 5 +- .../streaming/flume/sink/SparkSinkSuite.scala | 4 +- .../streaming/flume/EventTransformer.scala | 4 +- .../streaming/flume/FlumeBatchFetcher.scala | 3 +- .../streaming/flume/FlumeInputDStream.scala | 7 +- .../flume/FlumePollingInputDStream.scala | 6 +- .../streaming/flume/FlumeTestUtils.scala | 10 +- .../spark/streaming/flume/FlumeUtils.scala | 8 +- .../flume/PollingFlumeTestUtils.scala | 16 ++- .../flume/FlumePollingStreamSuite.scala | 8 +- .../streaming/flume/FlumeStreamSuite.scala | 2 +- .../streaming/kafka/KafkaTestUtils.scala | 4 +- .../spark/streaming/kafka/KafkaUtils.scala | 35 +++--- .../spark/streaming/zeromq/ZeroMQUtils.scala | 15 ++- .../kinesis/KinesisBackedBlockRDD.scala | 4 +- .../streaming/kinesis/KinesisReceiver.scala | 4 +- .../streaming/kinesis/KinesisTestUtils.scala | 3 +- .../kinesis/KinesisReceiverSuite.scala | 12 +- .../mllib/util/LinearDataGenerator.scala | 4 +- .../ml/classification/JavaOneVsRestSuite.java | 7 +- .../LogisticRegressionSuite.scala | 4 +- .../spark/mllib/classification/SVMSuite.scala | 4 +- .../optimization/GradientDescentSuite.scala | 4 +- .../spark/mllib/recommendation/ALSSuite.scala | 4 +- project/SparkBuild.scala | 8 +- python/pyspark/sql/column.py | 12 ++ python/pyspark/sql/dataframe.py | 4 +- scalastyle-config.xml | 7 ++ .../main/scala/org/apache/spark/sql/Row.scala | 12 +- .../spark/sql/catalyst/analysis/Catalog.scala | 4 +- .../spark/sql/DataFrameNaFunctions.scala | 8 +- .../apache/spark/sql/DataFrameReader.scala | 4 +- .../apache/spark/sql/DataFrameWriter.scala | 4 +- .../org/apache/spark/sql/GroupedData.scala | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 13 ++- .../org/apache/spark/sql/SQLContext.scala | 8 +- .../datasources/ResolvedDataSource.scala | 4 +- .../parquet/CatalystReadSupport.scala | 8 +- .../parquet/CatalystRowConverter.scala | 4 +- .../parquet/CatalystSchemaConverter.scala | 4 +- .../datasources/parquet/ParquetRelation.scala | 13 ++- .../parquet/ParquetTypesConverter.scala | 4 +- .../joins/ShuffledHashOuterJoin.scala | 6 +- .../spark/sql/execution/pythonUDFs.scala | 11 +- .../apache/spark/sql/JavaDataFrameSuite.java | 8 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 4 +- .../ParquetAvroCompatibilitySuite.scala | 3 +- .../parquet/ParquetCompatibilityTest.scala | 7 +- .../datasources/parquet/ParquetIOSuite.scala | 25 ++-- .../SparkExecuteStatementOperation.scala | 10 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 16 +-- .../thriftserver/SparkSQLCLIService.scala | 6 +- .../hive/thriftserver/SparkSQLDriver.scala | 14 +-- .../sql/hive/thriftserver/SparkSQLEnv.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveInspectors.scala | 40 +++---- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +- .../org/apache/spark/sql/hive/HiveQl.scala | 110 ++++++++++-------- .../org/apache/spark/sql/hive/HiveShim.scala | 5 +- .../spark/sql/hive/client/ClientWrapper.scala | 27 ++--- 
.../spark/sql/hive/client/HiveShim.scala | 14 +-- .../execution/DescribeHiveTableCommand.scala | 8 +- .../sql/hive/execution/HiveTableScan.scala | 9 +- .../hive/execution/InsertIntoHiveTable.scala | 12 +- .../hive/execution/ScriptTransformation.scala | 12 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 9 +- .../spark/sql/hive/orc/OrcRelation.scala | 8 +- .../apache/spark/sql/hive/test/TestHive.scala | 11 +- .../spark/sql/hive/client/FiltersSuite.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 29 +++-- .../sql/hive/execution/PruningSuite.scala | 7 +- .../sql/hive/execution/SQLQuerySuite.scala | 4 +- .../sql/sources/hadoopFsRelationSuites.scala | 8 +- .../streaming/api/java/JavaDStreamLike.scala | 12 +- .../streaming/api/java/JavaPairDStream.scala | 28 ++--- .../api/java/JavaStreamingContext.scala | 32 ++--- .../streaming/api/python/PythonDStream.scala | 5 +- .../spark/streaming/receiver/Receiver.scala | 6 +- .../streaming/scheduler/JobScheduler.scala | 4 +- .../scheduler/ReceivedBlockTracker.scala | 4 +- .../util/FileBasedWriteAheadLog.scala | 4 +- .../spark/streaming/JavaTestUtils.scala | 24 ++-- .../streaming/util/WriteAheadLogSuite.scala | 4 +- .../spark/tools/GenerateMIMAIgnore.scala | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 13 ++- .../spark/deploy/yarn/ExecutorRunnable.scala | 24 ++-- .../spark/deploy/yarn/YarnAllocator.scala | 19 ++- .../spark/deploy/yarn/YarnRMClient.scala | 8 +- .../deploy/yarn/BaseYarnClusterSuite.scala | 6 +- .../spark/deploy/yarn/ClientSuite.scala | 8 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 5 +- 171 files changed, 863 insertions(+), 880 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 2389c28b28395..fdb309e365f69 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -24,7 +24,7 @@ import scala.Option; import scala.Product2; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.immutable.Map; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -160,7 +160,7 @@ public long getPeakMemoryUsedBytes() { */ @VisibleForTesting public void write(Iterator> records) throws IOException { - write(JavaConversions.asScalaIterator(records)); + write(JavaConverters.asScalaIteratorConverter(records).asScala()); } @Override diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 92218832d256f..a387592783850 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,8 +21,8 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} -import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} @@ -398,7 +398,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]] + new ConcurrentHashMap[Int, 
Array[MapStatus]]().asScala } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 32df42d57dbd6..3b9c885bf97a7 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,9 +17,11 @@ package org.apache.spark -import java.io.{File, FileInputStream} -import java.security.{KeyStore, NoSuchAlgorithmException} -import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} +import java.io.File +import java.security.NoSuchAlgorithmException +import javax.net.ssl.SSLContext + +import scala.collection.JavaConverters._ import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -79,7 +81,6 @@ private[spark] case class SSLOptions( * object. It can be used then to compose the ultimate Akka configuration. */ def createAkkaConfig: Option[Config] = { - import scala.collection.JavaConversions._ if (enabled) { Some(ConfigFactory.empty() .withValue("akka.remote.netty.tcp.security.key-store", @@ -97,7 +98,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.asJava)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9849aff85d72e..f3da04a7f55d0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,8 +26,8 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} -import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -1546,7 +1546,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def getAllPools: Seq[Schedulable] = { assertNotStopped() // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq + taskScheduler.rootPool.schedulableQueue.asScala.toSeq } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index a1ebbecf93b7b..888763a3e8ebf 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -19,11 +19,12 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.net.{URI, URL} +import java.nio.charset.StandardCharsets +import java.util.Arrays import java.util.jar.{JarEntry, JarOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ -import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} @@ -71,7 +72,7 @@ private[spark] object TestUtils { 
files.foreach { case (k, v) => val entry = new JarEntry(k) jarStream.putNextEntry(entry) - ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream) + ByteStreams.copy(new ByteArrayInputStream(v.getBytes(StandardCharsets.UTF_8)), jarStream) } jarStream.close() jarFile.toURI.toURL @@ -125,7 +126,7 @@ private[spark] object TestUtils { } else { Seq() } - compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call() + compiler.getTask(null, null, null, options.asJava, null, Arrays.asList(sourceFile)).call() val fileName = className + ".class" val result = new File(fileName) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala index 0ae0b4ec042e2..891bcddeac286 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapred.InputSplit @@ -37,7 +37,7 @@ class JavaHadoopRDD[K, V](rdd: HadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala index ec4f3964d75e0..0f49279f3e647 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaNewHadoopRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.hadoop.mapreduce.InputSplit @@ -37,7 +37,7 @@ class JavaNewHadoopRDD[K, V](rdd: NewHadoopRDD[K, V]) def mapPartitionsWithInputSplit[R]( f: JFunction2[InputSplit, java.util.Iterator[(K, V)], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = { - new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, asJavaIterator(b)), + new JavaRDD(rdd.mapPartitionsWithInputSplit((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 8441bb3a3047e..fb787979c1820 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList, Map => JMap} import java.lang.{Iterable => JIterable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -142,7 +142,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) + new JavaPairRDD[K, 
V](rdd.sampleByKey(withReplacement, fractions.asScala, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). @@ -173,7 +173,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) /** * ::Experimental:: @@ -768,7 +768,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. */ - def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + def lookup(key: K): JList[V] = rdd.lookup(key).asJava /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( @@ -987,30 +987,27 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { - rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) + rddToPairRDDFunctions(rdd).mapValues(_.asJava) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava)) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava)) } private[spark] def cogroupResult3ToJava[K: ClassTag, V, W1, W2, W3]( rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2], Iterable[W3]))]) : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2], JIterable[W3]))] = { - rddToPairRDDFunctions(rdd) - .mapValues(x => - (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3), asJavaIterable(x._4))) + rddToPairRDDFunctions(rdd).mapValues(x => (x._1.asJava, x._2.asJava, x._3.asJava, x._4.asJava)) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c582488f16fe7..fc817cdd6a3f8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,7 +21,6 @@ import java.{lang => jl} import java.lang.{Iterable => JIterable, Long => JLong} import java.util.{Comparator, List => JList, Iterator => JIterator} -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -59,10 +58,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def splits: JList[Partition] = rdd.partitions.toSeq.asJava /** Set of partitions in this RDD. 
*/ - def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + def partitions: JList[Partition] = rdd.partitions.toSeq.asJava /** The partitioner of this RDD. */ def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) @@ -82,7 +81,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * subclasses of RDD. */ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = - asJavaIterator(rdd.iterator(split, taskContext)) + rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -99,7 +98,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -153,7 +152,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -164,7 +163,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) @@ -175,7 +174,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) } @@ -186,7 +185,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -197,7 +196,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) .map(x => x.doubleValue())) @@ -209,7 +208,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsToPair[K2, V2](f: 
PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } JavaPairRDD.fromRDD( rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) @@ -219,14 +218,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(x.asJava))) } /** * Return an RDD created by coalescing all elements within each partition into an array. */ def glom(): JavaRDD[JList[T]] = - new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + new JavaRDD(rdd.glom().map(_.toSeq.asJava)) /** * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of @@ -266,13 +265,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command)) + rdd.pipe(command.asScala) /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = - rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) + rdd.pipe(command.asScala, env.asScala) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -294,8 +293,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { - (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) + (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).iterator().asScala } JavaRDD.fromRDD( rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) @@ -333,22 +331,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. */ - def collect(): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.collect().toSeq - new java.util.ArrayList(arr) - } + def collect(): JList[T] = + rdd.collect().toSeq.asJava /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. */ - def toLocalIterator(): JIterator[T] = { - import scala.collection.JavaConversions._ - rdd.toLocalIterator - } - + def toLocalIterator(): JIterator[T] = + asJavaIteratorConverter(rdd.toLocalIterator).asJava /** * Return an array that contains all of the elements in this RDD. @@ -363,9 +355,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. 
- import scala.collection.JavaConversions._ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) - res.map(x => new java.util.ArrayList(x.toSeq)).toArray + res.map(_.toSeq.asJava) } /** @@ -489,20 +480,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * it will be slow if a lot of partitions are required. In that case, use collect() to get the * whole RDD instead. */ - def take(num: Int): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.take(num).toSeq - new java.util.ArrayList(arr) - } + def take(num: Int): JList[T] = + rdd.take(num).toSeq.asJava def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { - import scala.collection.JavaConversions._ - val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq - new java.util.ArrayList(arr) - } + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = + rdd.takeSample(withReplacement, num, seed).toSeq.asJava /** * Return the first element in this RDD. @@ -582,10 +567,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def top(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.top(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -607,10 +589,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * @return an array of top elements */ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = { - import scala.collection.JavaConversions._ - val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)) - val arr: java.util.Collection[T] = topElems.toSeq - new java.util.ArrayList(arr) + rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp)).toSeq.asJava } /** @@ -696,7 +675,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * applies a function f to each partition of this RDD. 
*/ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { - new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 02e49a853c5f7..609496ccdfef1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,8 +21,7 @@ import java.io.Closeable import java.util import java.util.{Map => JMap} -import scala.collection.JavaConversions -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -104,7 +103,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) private[spark] val env = sc.env @@ -118,7 +117,7 @@ class JavaSparkContext(val sc: SparkContext) def appName: String = sc.appName - def jars: util.List[String] = sc.jars + def jars: util.List[String] = sc.jars.asJava def startTime: java.lang.Long = sc.startTime @@ -142,7 +141,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag - sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) + sc.parallelize(list.asScala, numSlices) } /** Get an RDD that has no partitions or elements. */ @@ -161,7 +160,7 @@ class JavaSparkContext(val sc: SparkContext) : JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[V] = fakeClassTag - JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) + JavaPairRDD.fromRDD(sc.parallelize(list.asScala, numSlices)) } /** Distribute a local Scala collection to form an RDD. */ @@ -170,8 +169,7 @@ class JavaSparkContext(val sc: SparkContext) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = - JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), - numSlices)) + JavaDoubleRDD.fromRDD(sc.parallelize(list.asScala.map(_.doubleValue()), numSlices)) /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = @@ -519,7 +517,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { - val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[T]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[T] = first.classTag sc.union(rdds) } @@ -527,7 +525,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. 
*/ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { - val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) + val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.rdd) implicit val ctag: ClassTag[(K, V)] = first.classTag implicit val ctagK: ClassTag[K] = first.kClassTag implicit val ctagV: ClassTag[V] = first.vClassTag @@ -536,7 +534,7 @@ class JavaSparkContext(val sc: SparkContext) /** Build the union of two or more RDDs. */ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { - val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) + val rdds: Seq[RDD[Double]] = (Seq(first) ++ rest.asScala).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index b959b683d1674..a7dfa1d257cf2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -17,15 +17,17 @@ package org.apache.spark.api.python -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException} +import scala.collection.JavaConverters._ +import scala.util.{Failure, Success, Try} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * :: Experimental :: @@ -68,7 +70,6 @@ private[python] class WritableToJavaConverter( * object representation */ private def convertWritable(writable: Writable): Any = { - import collection.JavaConversions._ writable match { case iw: IntWritable => iw.get() case dw: DoubleWritable => dw.get() @@ -89,9 +90,7 @@ private[python] class WritableToJavaConverter( aw.get().map(convertWritable(_)) case mw: MapWritable => val map = new java.util.HashMap[Any, Any]() - mw.foreach { case (k, v) => - map.put(convertWritable(k), convertWritable(v)) - } + mw.asScala.foreach { case (k, v) => map.put(convertWritable(k), convertWritable(v)) } map case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other @@ -122,7 +121,6 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { * supported out-of-the-box. 
*/ private def convertToWritable(obj: Any): Writable = { - import collection.JavaConversions._ obj match { case i: java.lang.Integer => new IntWritable(i) case d: java.lang.Double => new DoubleWritable(d) @@ -134,7 +132,7 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { case null => NullWritable.get() case map: java.util.Map[_, _] => val mapWritable = new MapWritable() - map.foreach { case (k, v) => + map.asScala.foreach { case (k, v) => mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable @@ -161,9 +159,8 @@ private[python] object PythonHadoopUtil { * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] */ def mapToConf(map: java.util.Map[String, String]): Configuration = { - import collection.JavaConversions._ val conf = new Configuration() - map.foreach{ case (k, v) => conf.set(k, v) } + map.asScala.foreach { case (k, v) => conf.set(k, v) } conf } @@ -172,9 +169,8 @@ private[python] object PythonHadoopUtil { * any matching keys in left */ def mergeConfs(left: Configuration, right: Configuration): Configuration = { - import collection.JavaConversions._ val copy = new Configuration(left) - right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + right.asScala.foreach(entry => copy.set(entry.getKey, entry.getValue)) copy } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 2a56bf28d7027..b4d152b336602 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -21,7 +21,7 @@ import java.io._ import java.net._ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials @@ -66,11 +66,11 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { - envVars += ("SPARK_REUSE_WORKER" -> "1") + envVars.put("SPARK_REUSE_WORKER", "1") } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool @volatile var released = false @@ -150,7 +150,7 @@ private[spark] class PythonRDD( // Check whether the worker is ready to be re-used. 
if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) released = true } } @@ -217,13 +217,13 @@ private[spark] class PythonRDD( // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.length) - for (include <- pythonIncludes) { + dataOut.writeInt(pythonIncludes.size()) + for (include <- pythonIncludes.asScala) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet + val newBids = broadcastVars.asScala.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) val cnt = toRemove.size + newBids.diff(oldBids).size @@ -233,7 +233,7 @@ private[spark] class PythonRDD( dataOut.writeLong(- bid - 1) // bid >= 0 oldBids.remove(bid) } - for (broadcast <- broadcastVars) { + for (broadcast <- broadcastVars.asScala) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -287,7 +287,7 @@ private[spark] class PythonRDD( if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap, worker) + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -358,10 +358,10 @@ private[spark] object PythonRDD extends Logging { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions.asScala) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, - s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") + s"serve RDD ${rdd.id} with partitions ${partitions.asScala.mkString(",")}") } /** @@ -819,7 +819,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val in = socket.getInputStream val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) - for (array <- val2) { + for (array <- val2.asScala) { out.writeInt(array.length) out.write(array) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 90dacaeb93429..31e534f160eeb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,10 +17,10 @@ package org.apache.spark.api.python -import java.io.{File} +import java.io.File import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -51,7 +51,14 @@ private[spark] object PythonUtils { * Convert list of T into seq of T (for calling API with varargs) */ def toSeq[T](vs: JList[T]): Seq[T] = { - vs.toList.toSeq + vs.asScala + } + + /** + * Convert list of T into a (Scala) List of T + */ + def toList[T](vs: JList[T]): List[T] = { + vs.asScala.toList } /** @@ -65,6 +72,6 @@ private[spark] object PythonUtils { * Convert java map of 
K, V into Map of K, V (for calling API with varargs) */ def toScalaMap[K, V](jm: java.util.Map[K, V]): Map[K, V] = { - jm.toMap + jm.asScala.toMap } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index e314408c067e9..7039b734d2e40 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.python import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.util.Arrays import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.util.{RedirectThread, Utils} @@ -108,9 +109,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") @@ -151,9 +152,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() - workerEnv.putAll(envVars) + workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 1f1debcf84ad4..fd27276e70bfe 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -22,7 +22,6 @@ import java.util.{ArrayList => JArrayList} import org.apache.spark.api.java.JavaRDD -import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure @@ -214,7 +213,7 @@ private[spark] object SerDeUtil extends Logging { new AutoBatchedPickler(cleaned) } else { val pickle = new Pickler - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + cleaned.grouped(batchSize).map(batched => pickle.dumps(batched.asJava)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index 8f30ff9202c83..ee1fb056f0d96 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -20,6 +20,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} import java.{util => ju} 
+import scala.collection.JavaConverters._ + import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ @@ -62,10 +64,9 @@ private[python] class TestInputKeyConverter extends Converter[Any, Any] { } private[python] class TestInputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): ju.List[Double] = { val m = obj.asInstanceOf[MapWritable] - seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + m.keySet.asScala.map(_.asInstanceOf[DoubleWritable].get()).toSeq.asJava } } @@ -76,9 +77,8 @@ private[python] class TestOutputKeyConverter extends Converter[Any, Any] { } private[python] class TestOutputValueConverter extends Converter[Any, Any] { - import collection.JavaConversions._ override def convert(obj: Any): DoubleWritable = { - new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().head) + new DoubleWritable(obj.asInstanceOf[java.util.Map[Double, _]].keySet().iterator().next()) } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 1cf2824f862ee..9d5bbb5d609f3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,9 +19,10 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} +import java.util.Arrays import java.util.{Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try @@ -365,11 +366,11 @@ private[r] object RRDD { sparkConf.setIfMissing("spark.master", "local") } - for ((name, value) <- sparkEnvirMap) { - sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkEnvirMap.asScala) { + sparkConf.set(name.toString, value.toString) } - for ((name, value) <- sparkExecutorEnvMap) { - sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) + for ((name, value) <- sparkExecutorEnvMap.asScala) { + sparkConf.setExecutorEnv(name.toString, value.toString) } val jsc = new JavaSparkContext(sparkConf) @@ -395,7 +396,7 @@ private[r] object RRDD { val rOptions = "--vanilla" val rLibDir = RUtils.sparkRPackagePath(isDriver = false) val rExecScript = rLibDir + "/SparkR/worker/" + script - val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) // Unset the R_TESTS environment variable for workers. // This is set by R CMD check as startup.Rs // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 427b2bc7cbcbb..9e807cc52f18c 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -18,8 +18,7 @@ package org.apache.spark.api.r import java.io.File - -import scala.collection.JavaConversions._ +import java.util.Arrays import org.apache.spark.{SparkEnv, SparkException} @@ -68,7 +67,7 @@ private[spark] object RUtils { /** Check if R is installed before running tests that use R commands. 
*/ def isRInstalled: Boolean = { try { - val builder = new ProcessBuilder(Seq("R", "--version")) + val builder = new ProcessBuilder(Arrays.asList("R", "--version")) builder.start().waitFor() == 0 } catch { case e: Exception => false diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c89f24473744..dbbbcf40c1e96 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -20,7 +20,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} import java.sql.{Timestamp, Date, Time} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ /** * Utility functions to serialize, deserialize objects to / from R @@ -165,7 +165,7 @@ private[spark] object SerDe { val valueType = readObjectType(in) readTypedObject(in, valueType) }) - mapAsJavaMap(keys.zip(values).toMap) + keys.zip(values).toMap.asJava } else { new java.util.HashMap[Object, Object]() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index a0c9b5e63c744..7e3764d802fe1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer -import scala.collection.JavaConversions.asJavaEnumeration +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.Random @@ -210,7 +210,7 @@ private object TorrentBroadcast extends Logging { compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 22ef701d833b2..6840a3ae831f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils @@ -67,13 +67,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - val bootstraps = + val bootstraps: 
Seq[TransportServerBootstrap] = if (useSasl) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { Nil } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(port, bootstraps.asJava) } /** Clean up all shuffle files associated with an application that has exited. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 23d01e9cbb9f9..d85327603f64d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -21,7 +21,7 @@ import java.net.URI import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.spark.SparkUserAppException @@ -71,7 +71,7 @@ object PythonRunner { val pythonPath = PythonUtils.mergePythonPaths(pathElements: _*) // Launch Python process - val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) + val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava) val env = builder.environment() env.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index ed1e972955679..4b28866dcaa7c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -22,7 +22,7 @@ import java.util.jar.JarFile import java.util.logging.Level import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.{ByteStreams, Files} @@ -110,7 +110,7 @@ private[deploy] object RPackageUtils extends Logging { print(s"Building R package with the command: $installCmd", printStream) } try { - val builder = new ProcessBuilder(installCmd) + val builder = new ProcessBuilder(installCmd.asJava) builder.redirectErrorStream(true) val env = builder.environment() env.clear() diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index c0cab22fa8252..05b954ce36998 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy import java.io._ import java.util.concurrent.{Semaphore, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path @@ -68,7 +68,7 @@ object RRunner { if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { // Launch R val returnCode = try { - val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) + val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava) val env = builder.environment() env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString) val rPackageDir = RUtils.sparkRPackagePath(isDriver = true) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index b8d3993540220..8d5e716e6aea4 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} import org.apache.curator.retry.ExponentialBackoffRetry @@ -57,7 +57,7 @@ private[spark] object SparkCuratorUtil extends Logging { def deleteRecursive(zk: CuratorFramework, path: String) { if (zk.checkExists().forPath(path) != null) { - for (child <- zk.getChildren.forPath(path)) { + for (child <- zk.getChildren.forPath(path).asScala) { zk.delete().forPath(path + "/" + child) } zk.delete().forPath(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index dda4216c7efe2..f7723ef5bde4c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.util.{Arrays, Comparator} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.control.NonFatal @@ -71,7 +71,7 @@ class SparkHadoopUtil extends Logging { } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { - for (token <- source.getTokens()) { + for (token <- source.getTokens.asScala) { dest.addToken(token) } } @@ -175,8 +175,8 @@ class SparkHadoopUtil extends Logging { } private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - val stats = FileSystem.getAllStatistics() - stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + FileSystem.getAllStatistics.asScala.map( + Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { @@ -306,12 +306,13 @@ class SparkHadoopUtil extends Logging { val renewalInterval = sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) - credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + credentials.getAllTokens.asScala + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .map { t => - val identifier = new DelegationTokenIdentifier() - identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) - (identifier.getIssueDate + fraction * renewalInterval).toLong - now - }.foldLeft(0L)(math.max) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 3f3c6627c21fb..18a1c52ae53fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{List => JList} import java.util.jar.JarFile -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import 
scala.io.Source @@ -94,7 +94,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // Set parameters from command line arguments try { - parse(args.toList) + parse(args.asJava) } catch { case e: IllegalArgumentException => SparkSubmit.printErrorAndExit(e.getMessage()) @@ -458,7 +458,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } override protected def handleExtraArgs(extra: JList[String]): Unit = { - childArgs ++= extra + childArgs ++= extra.asScala } private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 563831cc6b8dd..540e802420ce0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.master import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.curator.framework.CuratorFramework @@ -49,8 +49,8 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer } override def read[T: ClassTag](prefix: String): Seq[T] = { - val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) - file.map(deserializeFromFile[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala + .filter(_.startsWith(prefix)).map(deserializeFromFile[T]).flatten } override def close() { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 45a3f43045437..ce02ee203a4bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} -import java.lang.System._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import org.apache.spark.Logging @@ -62,7 +61,7 @@ object CommandUtils extends Logging { // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows val cmd = new WorkerCommandBuilder(sparkHome, memory, command).buildCommand() - cmd.toSeq ++ Seq(command.mainClass) ++ command.arguments + cmd.asScala ++ Seq(command.mainClass) ++ command.arguments } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index ec51c3d935d8e..89159ff5e2b3c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -172,8 +172,8 @@ private[deploy] class DriverRunner( CommandUtils.redirectStream(process.getInputStream, stdout) val stderr = new File(baseDir, "stderr") - val header = "Launch Command: %s\n%s\n\n".format( - builder.command.mkString("\"", "\" \"", "\""), "=" * 
40) + val formattedCommand = builder.command.asScala.mkString("\"", "\" \"", "\"") + val header = "Launch Command: %s\n%s\n\n".format(formattedCommand, "=" * 40) Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } @@ -229,6 +229,6 @@ private[deploy] trait ProcessBuilderLike { private[deploy] object ProcessBuilderLike { def apply(processBuilder: ProcessBuilder): ProcessBuilderLike = new ProcessBuilderLike { override def start(): Process = processBuilder.start() - override def command: Seq[String] = processBuilder.command() + override def command: Seq[String] = processBuilder.command().asScala } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index ab3fea475c2a5..3aef0515cbf6e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import java.io._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -129,7 +129,8 @@ private[deploy] class ExecutorRunner( val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf), memory, sparkHome.getAbsolutePath, substituteVariables) val command = builder.command() - logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) + val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"") + logInfo(s"Launch command: $formattedCommand") builder.directory(executorDir) builder.environment.put("SPARK_EXECUTOR_DIRS", appLocalDirs.mkString(File.pathSeparator)) @@ -145,7 +146,7 @@ private[deploy] class ExecutorRunner( process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) + formattedCommand, "=" * 40) // Redirect its stdout and stderr to files val stdout = new File(executorDir, "stdout") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 79b1536d94016..770927c80f7a4 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -24,7 +24,6 @@ import java.util.{UUID, Date} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42a85e42ea2b6..c3491bb8b1cf3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,7 +23,7 @@ import java.net.URL import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -147,7 +147,7 @@ private[spark] class Executor( /** Returns the total amount of time this JVM process has spent in garbage collection. 
*/ private def computeTotalGcTime(): Long = { - ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum } class TaskRunner( @@ -425,7 +425,7 @@ private[spark] class Executor( val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]() val curGCTime = computeTotalGcTime() - for (taskRunner <- runningTasks.values()) { + for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { taskRunner.task.metrics.foreach { metrics => metrics.updateShuffleReadMetrics() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 293c512f8b70c..d16f4a1fc4e3b 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.util.concurrent.ThreadPoolExecutor -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem @@ -30,7 +30,7 @@ private[spark] class ExecutorSource(threadPool: ThreadPoolExecutor, executorId: String) extends Source { private def fileStats(scheme: String) : Option[FileSystem.Statistics] = - FileSystem.getAllStatistics().find(s => s.getScheme.equals(scheme)) + FileSystem.getAllStatistics.asScala.find(s => s.getScheme.equals(scheme)) private def registerFileSystemStat[T]( scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = { diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a97..0474fd2ccc12e 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} @@ -28,7 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} +import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData import org.apache.spark.util.{SignalLogger, Utils} private[spark] class MesosExecutorBackend @@ -55,7 +55,7 @@ private[spark] class MesosExecutorBackend slaveInfo: SlaveInfo) { // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. 
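
The computeTotalGcTime change just above shows the typical replacement for an implicit JavaConversions view: the java.util.List returned by a JDK API is wrapped explicitly with asScala before Scala combinators are applied. A standalone sketch of the same shape (the object name is illustrative):

import java.lang.management.ManagementFactory

import scala.collection.JavaConverters._

object GcTimeSketch {
  // Mirrors Executor.computeTotalGcTime after this patch: decorate the
  // java.util.List of GarbageCollectorMXBeans with asScala, then map and sum.
  def totalGcTimeMillis(): Long =
    ManagementFactory.getGarbageCollectorMXBeans.asScala.map(_.getCollectionTime).sum

  def main(args: Array[String]): Unit =
    println(s"Total GC time so far: ${totalGcTimeMillis()} ms")
}
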
- val cpusPerTask = executorInfo.getResourcesList + val cpusPerTask = executorInfo.getResourcesList.asScala .find(_.getName == "cpus") .map(_.getScalar.getValue.toInt) .getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 6cda7772f77bc..a5ad47293f1c2 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.input import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.io.ByteStreams import org.apache.hadoop.conf.Configuration @@ -44,12 +44,9 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum - - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong super.setMaxSplitSize(maxSplitSize) } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index aaef7c74eea33..1ba34a11414a2 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -17,7 +17,7 @@ package org.apache.spark.input -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit @@ -52,10 +52,8 @@ private[spark] class WholeTextFileInputFormat * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context) - val totalLen = files.map { file => - if (file.isDir) 0L else file.getLen - }.sum + val files = listStatus(context).asScala + val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 9be98723aed14..0c096656f9236 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -20,7 +20,7 @@ package org.apache.spark.launcher import java.io.File import java.util.{HashMap => JHashMap, List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.deploy.Command @@ -32,7 +32,7 @@ import org.apache.spark.deploy.Command private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, command: Command) extends AbstractCommandBuilder { - childEnv.putAll(command.environment) + childEnv.putAll(command.environment.asJava) childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sparkHome) override def buildCommand(env: JMap[String, 
String]): JList[String] = { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index d7495551ad233..dd2d325d87034 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics import java.io.{FileInputStream, InputStream} import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex @@ -58,25 +59,20 @@ private[spark] class MetricsConfig(conf: SparkConf) extends Logging { propertyCategories = subProperties(properties, INSTANCE_REGEX) if (propertyCategories.contains(DEFAULT_PREFIX)) { - import scala.collection.JavaConversions._ - - val defaultProperty = propertyCategories(DEFAULT_PREFIX) - for { (inst, prop) <- propertyCategories - if (inst != DEFAULT_PREFIX) - (k, v) <- defaultProperty - if (prop.getProperty(k) == null) } { - prop.setProperty(k, v) + val defaultProperty = propertyCategories(DEFAULT_PREFIX).asScala + for((inst, prop) <- propertyCategories if (inst != DEFAULT_PREFIX); + (k, v) <- defaultProperty if (prop.get(k) == null)) { + prop.put(k, v) } } } def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = { val subProperties = new mutable.HashMap[String, Properties] - import scala.collection.JavaConversions._ - prop.foreach { kv => - if (regex.findPrefixOf(kv._1).isDefined) { - val regex(prefix, suffix) = kv._1 - subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2) + prop.asScala.foreach { kv => + if (regex.findPrefixOf(kv._1.toString).isDefined) { + val regex(prefix, suffix) = kv._1.toString + subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2.toString) } } subProperties diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2b..7c170a742fb64 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager @@ -55,7 +55,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) + val streamId = streamManager.registerStream(blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index d650d5fe73087..ff8aae9ebe9f0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -17,7 +17,7 @@ package org.apache.spark.network.netty -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import 
scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} @@ -58,7 +58,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) @@ -67,7 +67,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps) + val server = transportContext.createServer(port, bootstraps.asJava) (server, server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 1499da07bb83b..8d9ebadaf79d4 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -23,7 +23,7 @@ import java.nio.channels._ import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -145,7 +145,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, } def callOnExceptionCallbacks(e: Throwable) { - onExceptionCallbacks foreach { + onExceptionCallbacks.asScala.foreach { callback => try { callback(this, e) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 91b07ce3af1b6..5afce75680f94 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag @@ -48,9 +48,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => - result(key) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(key, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -64,9 +64,9 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(key) = new BoundedDouble(mean, confidence, low, high) + result.put(key, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala 
b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala index af26c3d59ac02..a164040684803 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub while (iter.hasNext) { val entry = iter.next() val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + result.put(entry.getKey, new BoundedDouble(mean, 1.0, mean, mean)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -72,9 +72,9 @@ private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Doub val confFactor = studentTCacher.get(counter.count) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(mean, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala index 442fb86227d86..54a1beab3514b 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.partial import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions.mapAsScalaMap +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.HashMap @@ -55,9 +55,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl while (iter.hasNext) { val entry = iter.next() val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + result.put(entry.getKey, new BoundedDouble(sum, 1.0, sum, sum)) } - result + result.asScala } else if (outputsMerged == 0) { new HashMap[T, BoundedDouble] } else { @@ -80,9 +80,9 @@ private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Doubl val confFactor = studentTCacher.get(counter.count) val low = sumEstimate - confFactor * sumStdev val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + result.put(entry.getKey, new BoundedDouble(sumEstimate, confidence, low, high)) } - result + result.asScala } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 326fafb230a40..4e5f2e8a5d467 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -22,7 +22,7 @@ import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} import scala.collection.{Map, mutable} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.DynamicVariable @@ -312,14 +312,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } : Iterator[JHashMap[K, V]] val 
mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { - m2.foreach { pair => + m2.asScala.foreach { pair => val old = m1.get(pair._1) m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] - self.mapPartitions(reducePartition).reduce(mergeMaps) + self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } /** Alias for reduceByKeyLocally */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 3bb9998e1db44..afbe566b76566 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -23,7 +23,7 @@ import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -72,7 +72,7 @@ private[spark] class PipedRDD[T: ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[String] = { - val pb = new ProcessBuilder(command) + val pb = new ProcessBuilder(command.asJava) // Add the environmental variables to the process. val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } @@ -81,7 +81,7 @@ private[spark] class PipedRDD[T: ClassTag]( // so the user code can access the input filename if (split.isInstanceOf[HadoopPartition]) { val hadoopSplit = split.asInstanceOf[HadoopPartition] - currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava) } // When spark.worker.separated.working.directory option is turned on, each diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index f7cb1791d4ac6..9a4fa301b06e3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.asScala.iterator.map(t => t._2.iterator.map((t._1, _))).flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index bac37bfdaa23f..0e438ab4366d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable.Set import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -107,7 +107,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) - for (split <- list) { + for (split <- list.asScala) { retval 
++= SplitInfo.toSplitInfo(inputFormatClazz, path, split) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc0..5821afea98982 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -74,7 +74,7 @@ private[spark] class Pool( if (schedulableNameToSchedulable.containsKey(schedulableName)) { return schedulableNameToSchedulable.get(schedulableName) } - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { val sched = schedulable.getSchedulableByName(schedulableName) if (sched != null) { return sched @@ -84,12 +84,12 @@ private[spark] class Pool( } override def executorLost(executorId: String, host: String) { - schedulableQueue.foreach(_.executorLost(executorId, host)) + schedulableQueue.asScala.foreach(_.executorLost(executorId, host)) } override def checkSpeculatableTasks(): Boolean = { var shouldRevive = false - for (schedulable <- schedulableQueue) { + for (schedulable <- schedulableQueue.asScala) { shouldRevive |= schedulable.checkSpeculatableTasks() } shouldRevive @@ -98,7 +98,7 @@ private[spark] class Pool( override def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] = { var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager] val sortedSchedulableQueue = - schedulableQueue.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) + schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d6e1e9e5bebc2..452c32d5411cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap @@ -233,7 +233,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { val filters = Filters.newBuilder().setRefuseSeconds(5).build() - for (offer <- offers) { + for (offer <- offers.asScala) { val offerAttributes = toAttributeMap(offer.getAttributesList) val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) val slaveId = offer.getSlaveId.getValue @@ -251,21 +251,21 @@ private[spark] class CoarseMesosSchedulerBackend( val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) totalCoresAcquired += cpusToUse val taskId = newMesosTaskId() - taskIdToSlaveId(taskId) = slaveId + taskIdToSlaveId.put(taskId, slaveId) slaveIdsWithExecutors += slaveId coresByTaskId(taskId) = cpusToUse // Gather cpu resources from the 
available resources and use them in the task. val (remainingResources, cpuResourcesToUse) = partitionResources(offer.getResourcesList, "cpus", cpusToUse) val (_, memResourcesToUse) = - partitionResources(remainingResources, "mem", calculateTotalMemory(sc)) + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) val taskBuilder = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse) - .addAllResources(memResourcesToUse) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => MesosSchedulerBackendUtil @@ -314,9 +314,9 @@ private[spark] class CoarseMesosSchedulerBackend( } if (TaskState.isFinished(TaskState.fromMesos(state))) { - val slaveId = taskIdToSlaveId(taskId) + val slaveId = taskIdToSlaveId.get(taskId) slaveIdsWithExecutors -= slaveId - taskIdToSlaveId -= taskId + taskIdToSlaveId.remove(taskId) // Remove the cores we have remembered for this task, if it's in the hashmap for (cores <- coresByTaskId.get(taskId)) { totalCoresAcquired -= cores @@ -361,7 +361,7 @@ private[spark] class CoarseMesosSchedulerBackend( stateLock.synchronized { if (slaveIdsWithExecutors.contains(slaveId)) { val slaveIdToTaskId = taskIdToSlaveId.inverse() - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { val taskId: Int = slaveIdToTaskId.get(slaveId) taskIdToSlaveId.remove(taskId) removeExecutor(sparkExecutorId(slaveId, taskId.toString), reason) @@ -411,7 +411,7 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveIdToTaskId = taskIdToSlaveId.inverse() for (executorId <- executorIds) { val slaveId = executorId.split("/")(0) - if (slaveIdToTaskId.contains(slaveId)) { + if (slaveIdToTaskId.containsKey(slaveId)) { mesosDriver.killTask( TaskID.newBuilder().setValue(slaveIdToTaskId.get(slaveId).toString).build()) pendingRemovedSlaveIds += slaveId diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3efc536f1456c..e0c547dce6d07 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster.mesos -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode @@ -129,6 +129,6 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( } override def fetchAll[T](): Iterable[T] = { - zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + zk.getChildren.forPath(WORKING_DIR).asScala.flatMap(fetch[T]) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 1206f184fbc82..07da9242b9922 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -21,7 +21,7 @@ import java.io.File 
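
Several call sites in these Mesos and Netty hunks pass a Scala buffer to a Java API by decorating it with asJava, for example addAllResources(cpuResourcesToUse.asJava) above and createServer(port, bootstraps.asJava) earlier in the patch. The decorator returns a live wrapper view rather than a copy. A small sketch, where javaOnlyApi is a hypothetical stand-in for such a Java method:

import java.util.{List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

object BufferToJavaListSketch {
  // Hypothetical stand-in for a Java API that only accepts java.util.List,
  // in the spirit of the Mesos and Netty methods called in this patch.
  def javaOnlyApi(items: JList[String]): Int = items.size()

  def main(args: Array[String]): Unit = {
    val tasks = ArrayBuffer("task-1", "task-2")
    // asJava wraps the buffer in a live java.util.List view; later additions
    // to the buffer are visible through the wrapper as well.
    println(javaOnlyApi(tasks.asJava))
  }
}
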
import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -350,7 +350,7 @@ private[spark] class MesosClusterScheduler( } // TODO: Page the status updates to avoid trying to reconcile // a large amount of tasks at once. - driver.reconcileTasks(statuses) + driver.reconcileTasks(statuses.toSeq.asJava) } } } @@ -493,10 +493,10 @@ private[spark] class MesosClusterScheduler( } override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { - val currentOffers = offers.map { o => + val currentOffers = offers.asScala.map(o => new ResourceOffer( o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) - }.toList + ).toList logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() val currentTime = new Date() @@ -521,10 +521,10 @@ private[spark] class MesosClusterScheduler( currentOffers, tasks) } - tasks.foreach { case (offerId, tasks) => - driver.launchTasks(Collections.singleton(offerId), tasks) + tasks.foreach { case (offerId, taskInfos) => + driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - offers + offers.asScala .filter(o => !tasks.keySet.contains(o.getId)) .foreach(o => driver.declineOffer(o.getId)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 5c20606d58715..2e424054be785 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File import java.util.{ArrayList => JArrayList, Collections, List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} @@ -129,14 +129,12 @@ private[spark] class MesosSchedulerBackend( val (resourcesAfterCpu, usedCpuResources) = partitionResources(availableResources, "cpus", scheduler.CPUS_PER_TASK) val (resourcesAfterMem, usedMemResources) = - partitionResources(resourcesAfterCpu, "mem", calculateTotalMemory(sc)) + partitionResources(resourcesAfterCpu.asJava, "mem", calculateTotalMemory(sc)) - builder.addAllResources(usedCpuResources) - builder.addAllResources(usedMemResources) + builder.addAllResources(usedCpuResources.asJava) + builder.addAllResources(usedMemResources.asJava) - sc.conf.getOption("spark.mesos.uris").map { uris => - setupUris(uris, command) - } + sc.conf.getOption("spark.mesos.uris").foreach(setupUris(_, command)) val executorInfo = builder .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -148,7 +146,7 @@ private[spark] class MesosSchedulerBackend( .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) } - (executorInfo.build(), resourcesAfterMem) + (executorInfo.build(), resourcesAfterMem.asJava) } /** @@ -193,7 +191,7 @@ private[spark] class MesosSchedulerBackend( private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { val builder = new StringBuilder - tasks.foreach { t => + tasks.asScala.foreach { t => 
builder.append("Task id: ").append(t.getTaskId.getValue).append("\n") .append("Slave id: ").append(t.getSlaveId.getValue).append("\n") .append("Task resources: ").append(t.getResourcesList).append("\n") @@ -211,7 +209,7 @@ private[spark] class MesosSchedulerBackend( override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.partition { o => + val (usableOffers, unUsableOffers) = offers.asScala.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue @@ -323,10 +321,10 @@ private[spark] class MesosSchedulerBackend( .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build()) .setExecutor(executorInfo) .setName(task.name) - .addAllResources(cpuResources) + .addAllResources(cpuResources.asJava) .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString) .build() - (taskInfo, finalResources) + (taskInfo, finalResources.asJava) } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 5b854aa5c2754..860c8e097b3b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.{List => JList} import java.util.concurrent.CountDownLatch -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -137,7 +137,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { protected def getResource(res: JList[Resource], name: String): Double = { // A resource can have multiple values in the offer since it can either be from // a specific role or wildcard. 
- res.filter(_.getName == name).map(_.getScalar.getValue).sum + res.asScala.filter(_.getName == name).map(_.getScalar.getValue).sum } protected def markRegistered(): Unit = { @@ -169,7 +169,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { amountToUse: Double): (List[Resource], List[Resource]) = { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] - val remainingResources = resources.map { + val remainingResources = resources.asScala.map { case r => { if (remain > 0 && r.getType == Value.Type.SCALAR && @@ -214,7 +214,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.map(attr => { + offerAttributes.asScala.map(attr => { val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -253,7 +253,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { requiredValues.map(_.toLong).exists(offerRange.contains(_)) case Some(offeredValue: Value.Set) => // check if the specified required values is a subset of offered set - requiredValues.subsetOf(offeredValue.getItemList.toSet) + requiredValues.subsetOf(offeredValue.getItemList.asScala.toSet) case Some(textValue: Value.Text) => // check if the specified value is equal, if multiple values are specified // we succeed if any of them match. @@ -299,14 +299,13 @@ private[mesos] trait MesosSchedulerUtils extends Logging { Map() } else { try { - Map() ++ mapAsScalaMap(splitter.split(constraintsVal)).map { - case (k, v) => - if (v == null || v.isEmpty) { - (k, Set[String]()) - } else { - (k, v.split(',').toSet) - } - } + splitter.split(constraintsVal).asScala.toMap.mapValues(v => + if (v == null || v.isEmpty) { + Set[String]() + } else { + v.split(',').toSet + } + ) } catch { case NonFatal(e) => throw new IllegalArgumentException(s"Bad constraint string: $constraintsVal", e) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0ff7562e912ca..048a938507277 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -21,6 +21,7 @@ import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer import javax.annotation.Nullable +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -373,16 +374,15 @@ private class JavaIterableWrapperSerializer override def read(kryo: Kryo, in: KryoInput, clz: Class[java.lang.Iterable[_]]) : java.lang.Iterable[_] = { kryo.readClassAndObject(in) match { - case scalaIterable: Iterable[_] => - scala.collection.JavaConversions.asJavaIterable(scalaIterable) - case javaIterable: java.lang.Iterable[_] => - javaIterable + case scalaIterable: Iterable[_] => scalaIterable.asJava + case javaIterable: java.lang.Iterable[_] => javaIterable } } } private object JavaIterableWrapperSerializer extends Logging { - // The class returned by asJavaIterable (scala.collection.convert.Wrappers$IterableWrapper). + // The class returned by JavaConverters.asJava + // (scala.collection.convert.Wrappers$IterableWrapper). 
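
The serializer comment above notes that asJava on a Scala Iterable yields scala.collection.convert.Wrappers$IterableWrapper, which is the class captured by the wrapperClass val right below. A tiny sketch of the conversion it describes (names are illustrative; the printed class name is what the comment refers to):

import scala.collection.JavaConverters._

object IterableWrapperSketch {
  def main(args: Array[String]): Unit = {
    // Typing the value as Iterable selects the Iterable -> java.lang.Iterable
    // decorator instead of the Seq -> java.util.List one.
    val source: Iterable[Int] = Seq(1, 2, 3)
    val wrapped: java.lang.Iterable[Int] = source.asJava
    // Expected to print scala.collection.convert.Wrappers$IterableWrapper.
    println(wrapped.getClass.getName)
  }
}
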
val wrapperClass = scala.collection.convert.WrapAsJava.asJavaIterable(Seq(1)).getClass diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index f6a96d81e7aa9..c057de9b3f4df 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics @@ -210,11 +210,13 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) shuffleStates.get(shuffleId) match { case Some(state) => if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + for (fileGroup <- state.allFileGroups.asScala; + file <- fileGroup.files) { file.delete() } } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + for (mapId <- state.completedMapTasks.asScala; + reduceId <- 0 until state.numBuckets) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) blockManager.diskBlockManager.getFile(blockId).delete() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 6fec5240707a6..7db6035553ae6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.immutable.HashSet import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} @@ -133,7 +133,7 @@ class BlockManagerMasterEndpoint( // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. 
- val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -242,7 +242,7 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks) + new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) }.toArray } @@ -292,7 +292,7 @@ class BlockManagerMasterEndpoint( if (askSlaves) { info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { - Future { info.blocks.keys.filter(filter).toSeq } + Future { info.blocks.asScala.keys.filter(filter).toSeq } } future } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 78e7ddc27d1c7..1738258a0c794 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import scala.collection.JavaConversions.mapAsJavaMap +import scala.collection.JavaConverters._ import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging { val akkaSslConfig = securityManager.akkaSSLOptions.createAkkaConfig .getOrElse(ConfigFactory.empty()) - val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap.asJava) .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString( s""" |akka.daemonic = on diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index a725767d08cc2..13cb516b583e9 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -19,12 +19,11 @@ package org.apache.spark.util import java.util.concurrent.CopyOnWriteArrayList -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.scheduler.SparkListener /** * An event bus which posts events to its listeners. @@ -46,7 +45,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * `postToAll` in the same thread for all events. */ final def postToAll(event: E): Unit = { - // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here ewe use // Java Iterator directly. 
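
The ListenerBus comment just above spells out the trade-off this patch keeps in mind: asScala is convenient, but it allocates a wrapper, so the hot postToAll path sticks with the raw Java iterator while infrequent paths such as findListenersByClass use asScala. A small sketch of both styles against a CopyOnWriteArrayList (the listener values are illustrative):

import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.JavaConverters._

object ListenerIterationSketch {
  def main(args: Array[String]): Unit = {
    val listeners = new CopyOnWriteArrayList[String]()
    listeners.add("jobListener")
    listeners.add("stageListener")

    // Hot path: iterate with the Java iterator directly, avoiding the
    // per-call wrapper that asScala would allocate.
    val iter = listeners.iterator()
    while (iter.hasNext) {
      println(iter.next())
    }

    // Infrequent path: the asScala view is fine here and reads naturally.
    println(listeners.asScala.filter(_.startsWith("job")).mkString(", "))
  }
}
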
val iter = listeners.iterator @@ -69,7 +68,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = { val c = implicitly[ClassTag[T]].runtimeClass - listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq + listeners.asScala.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 169489df6c1ea..a1c33212cdb2b 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -21,8 +21,6 @@ import java.net.{URLClassLoader, URL} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ - /** * URL class loader that exposes the `addURL` and `getURLs` methods in URLClassLoader. */ diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 8de75ba9a9c92..d7e5143c30953 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,8 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{JavaConversions, mutable} +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.Logging @@ -50,8 +51,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + getEntrySet.iterator.asScala.map(kv => (kv.getKey, kv.getValue.value)) } def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet @@ -90,9 +90,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap) - .map { case (k, TimeStampedValue(v, t)) => (k, v) } - .filter(p) + internalMap.asScala.map { case (k, TimeStampedValue(v, t)) => (k, v) }.filter(p) } override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala index 7cd8f28b12dd6..65efeb1f4c19c 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions +import scala.collection.JavaConverters._ import scala.collection.mutable.Set private[spark] class TimeStampedHashSet[A] extends Set[A] { @@ -31,7 +31,7 @@ private[spark] class TimeStampedHashSet[A] extends Set[A] { def iterator: Iterator[A] = { val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(_.getKey) + jIterator.asScala.map(_.getKey) } override def + (elem: A): Set[A] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8313312226713..2bab4af2e73ab 100644 --- 
a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -25,7 +25,7 @@ import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source @@ -748,12 +748,12 @@ private[spark] object Utils extends Logging { // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order // on unix-like system. On windows, it returns in index order. // It's more proper to pick ip address following system output order. - val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse for (ni <- reOrderedNetworkIFs) { - val addresses = ni.getInetAddresses.toList - .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress) + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq if (addresses.nonEmpty) { val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) // because of Inet6Address.toHostName may add interface at the end if it knows about it @@ -1498,10 +1498,8 @@ private[spark] object Utils extends Logging { * properties which have been set explicitly, as well as those for which only a default value * has been defined. */ def getSystemProperties: Map[String, String] = { - val sysProps = for (key <- System.getProperties.stringPropertyNames()) yield - (key, System.getProperty(key)) - - sysProps.toMap + System.getProperties.stringPropertyNames().asScala + .map(key => (key, System.getProperty(key))).toMap } /** @@ -1812,7 +1810,8 @@ private[spark] object Utils extends Logging { try { val properties = new Properties() properties.load(inReader) - properties.stringPropertyNames().map(k => (k, properties(k).trim)).toMap + properties.stringPropertyNames().asScala.map( + k => (k, properties.getProperty(k).trim)).toMap } catch { case e: IOException => throw new SparkException(s"Failed when loading Spark properties from $filename", e) @@ -1941,7 +1940,8 @@ private[spark] object Utils extends Logging { return true } isBindCollision(e.getCause) - case e: MultiException => e.getThrowables.exists(isBindCollision) + case e: MultiException => + e.getThrowables.asScala.exists(isBindCollision) case e: Exception => isBindCollision(e.getCause) case _ => false } diff --git a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala index bdbca00a00622..4939b600dbfbd 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Utils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import scala.collection.JavaConversions.{collectionAsScalaIterable, asJavaIterator} +import scala.collection.JavaConverters._ import com.google.common.collect.{Ordering => GuavaOrdering} @@ -34,6 +34,6 @@ private[spark] object Utils { val ordering = new GuavaOrdering[T] { override def compare(l: T, r: T): Int = ord.compare(l, r) } - collectionAsScalaIterable(ordering.leastOf(asJavaIterator(input), num)).iterator + ordering.leastOf(input.asJava, num).iterator.asScala } } diff --git 
a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ffe4b4baffb2a..ebd3d61ae7324 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -24,10 +24,10 @@ import java.util.*; import java.util.concurrent.*; -import scala.collection.JavaConversions; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; +import scala.collection.JavaConverters; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -1473,7 +1473,9 @@ public Integer call(Integer v1, Integer v2) throws Exception { Assert.assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.>newArrayList())); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) .combineByKey( createCombinerFunction, diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 90cb7da94e88a..ff9a92cc0a421 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.{Try, Random} @@ -148,7 +149,6 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } test("Thread safeness - SPARK-5425") { - import scala.collection.JavaConversions._ val executor = Executors.newSingleThreadScheduledExecutor() val sf = executor.scheduleAtFixedRate(new Runnable { override def run(): Unit = @@ -163,8 +163,9 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } finally { executor.shutdownNow() - for (key <- System.getProperties.stringPropertyNames() if key.startsWith("spark.5425.")) - System.getProperties.remove(key) + val sysProps = System.getProperties + for (key <- sysProps.stringPropertyNames().asScala if key.startsWith("spark.5425.")) + sysProps.remove(key) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index cbd2aee10c0e2..86eb41dd7e5d7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy import java.net.URL -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.io.Source diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 47a64081e297e..1ed4bae3ca21e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -21,14 +21,14 @@ import java.io.{PrintStream, OutputStream, File} import java.net.URI import java.util.jar.Attributes.Name import java.util.jar.{JarFile, Manifest} -import java.util.zip.{ZipEntry, ZipFile} +import java.util.zip.ZipFile -import org.scalatest.BeforeAndAfterEach -import scala.collection.JavaConversions._ +import 
scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils @@ -142,7 +142,7 @@ class RPackageUtilsSuite extends SparkFunSuite with BeforeAndAfterEach { IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") assert(finalZip.exists()) - val entries = new ZipFile(finalZip).entries().toSeq.map(_.getName) + val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq assert(entries.contains("/test.R")) assert(entries.contains("/SparkR/abc.R")) assert(entries.contains("/SparkR/DESCRIPTION")) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index bed6f3ea61241..98664dc1101e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.worker import java.io.File -import scala.collection.JavaConversions._ - import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -36,6 +34,7 @@ class ExecutorRunnerTest extends SparkFunSuite { ExecutorState.RUNNING) val builder = CommandUtils.buildProcessBuilder( appDesc.command, new SecurityManager(conf), 512, sparkHome, er.substituteVariables) - assert(builder.command().last === appId) + val builderCommand = builder.command() + assert(builderCommand.get(builderCommand.size() - 1) === appId) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 730535ece7878..a9652d7e7d0b0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import java.util.concurrent.Semaphore import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.scalatest.Matchers @@ -365,10 +365,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + classOf[BasicJobCounter].getName) sc = new SparkContext(conf) - sc.listenerBus.listeners.collect { case x: BasicJobCounter => x}.size should be (1) - sc.listenerBus.listeners.collect { - case x: ListenerThatAcceptsSparkConf => x - }.size should be (1) + sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index 5ed30f64d705f..319b3173e7a6e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.scheduler.cluster.mesos import 
java.nio.ByteBuffer -import java.util +import java.util.Arrays +import java.util.Collection import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -61,7 +62,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") - val resources = List( + val resources = Arrays.asList( mesosSchedulerBackend.createResource("cpus", 4), mesosSchedulerBackend.createResource("mem", 1024)) // uri is null. @@ -98,7 +99,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") val (execInfo, _) = backend.createExecutorInfo( - List(backend.createResource("cpus", 4)), "mockExecutor") + Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) val portmaps = execInfo.getContainer.getDocker.getPortMappingsList assert(portmaps.get(0).getHostPort.equals(80)) @@ -179,7 +180,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -279,7 +280,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1) - val capture = ArgumentCaptor.forClass(classOf[util.Collection[TaskInfo]]) + val capture = ArgumentCaptor.forClass(classOf[Collection[TaskInfo]]) when( driver.launchTasks( Matchers.eq(Collections.singleton(mesosOffers.get(0).getId)), @@ -304,7 +305,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi assert(cpusDev.getName.equals("cpus")) assert(cpusDev.getScalar.getValue.equals(1.0)) assert(cpusDev.getRole.equals("dev")) - val executorResources = taskInfo.getExecutor.getResourcesList + val executorResources = taskInfo.getExecutor.getResourcesList.asScala assert(executorResources.exists { r => r.getName.equals("mem") && r.getScalar.getValue.equals(484.0) && r.getRole.equals("prod") }) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 23a1fdb0f5009..8d1c9d17e977e 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag @@ -173,7 +174,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { test("asJavaIterable") { // Serialize a collection wrapped by asJavaIterable val ser = new KryoSerializer(conf).newInstance() - val a = ser.serialize(scala.collection.convert.WrapAsJava.asJavaIterable(Seq(12345))) + val a = ser.serialize(Seq(12345).asJava) val b = 
ser.deserialize[java.lang.Iterable[Int]](a) assert(b.iterator().next() === 12345) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 69888b2694bae..22e30ecaf0533 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -21,7 +21,6 @@ import java.net.{HttpURLConnection, URL} import javax.servlet.http.{HttpServletResponse, HttpServletRequest} import scala.io.Source -import scala.collection.JavaConversions._ import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler @@ -341,15 +340,15 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B // The completed jobs table should have two rows. The first row will be the most recent job: val firstRow = find(cssSelector("tbody tr")).get.underlying val firstRowColumns = firstRow.findElements(By.tagName("td")) - firstRowColumns(0).getText should be ("1") - firstRowColumns(4).getText should be ("1/1 (2 skipped)") - firstRowColumns(5).getText should be ("8/8 (16 skipped)") + firstRowColumns.get(0).getText should be ("1") + firstRowColumns.get(4).getText should be ("1/1 (2 skipped)") + firstRowColumns.get(5).getText should be ("8/8 (16 skipped)") // The second row is the first run of the job, where nothing was skipped: val secondRow = findAll(cssSelector("tbody tr")).toSeq(1).underlying val secondRowColumns = secondRow.findElements(By.tagName("td")) - secondRowColumns(0).getText should be ("0") - secondRowColumns(4).getText should be ("3/3") - secondRowColumns(5).getText should be ("24/24") + secondRowColumns.get(0).getText should be ("0") + secondRowColumns.get(4).getText should be ("3/3") + secondRowColumns.get(5).getText should be ("24/24") } } } @@ -502,8 +501,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expJobInfo(idx)._1) description should include (expJobInfo(idx)._2) @@ -547,8 +546,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B for { (row, idx) <- rows.zipWithIndex columns = row.findElements(By.tagName("td")) - id = columns(0).getText() - description = columns(1).getText() + id = columns.get(0).getText() + description = columns.get(1).getText() } { id should be (expStageInfo(idx)._1) description should include (expStageInfo(idx)._2) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 36832f51d2ad4..fa07c1e5017cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -19,10 +19,7 @@ package org.apache.spark.examples import java.nio.ByteBuffer - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ListBuffer -import scala.collection.immutable.Map +import java.util.Collections import org.apache.cassandra.hadoop.ConfigHelper import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat @@ -32,7 +29,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import 
org.apache.spark.SparkContext._ /* @@ -121,12 +117,9 @@ object CassandraCQLTest { val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { - val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) - val outKey: java.util.Map[String, ByteBuffer] = outColFamKey - var outColFamVal = new ListBuffer[ByteBuffer] - outColFamVal += ByteBufferUtil.bytes(saleCount) - val outVal: java.util.List[ByteBuffer] = outColFamVal - (outKey, outVal) + val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) + val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) + (outKey, outVal) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 96ef3e198e380..2e56d24c60c33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -19,10 +19,9 @@ package org.apache.spark.examples import java.nio.ByteBuffer +import java.util.Arrays import java.util.SortedMap -import scala.collection.JavaConversions._ - import org.apache.cassandra.db.IColumn import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat import org.apache.cassandra.hadoop.ConfigHelper @@ -32,7 +31,6 @@ import org.apache.cassandra.utils.ByteBufferUtil import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ /* * This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra @@ -118,7 +116,7 @@ object CassandraTest { val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis) - val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil + val mutations = Arrays.asList(new Mutation(), new Mutation()) mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(0).column_or_supercolumn.setColumn(colWord) mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index c42df2b8845d2..bec61f3cd4296 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils @@ -36,10 +36,10 @@ object DriverSubmissionTest { val properties = Utils.getSystemProperties println("Environment variables containing SPARK_TEST:") - env.filter{case (k, v) => k.contains("SPARK_TEST")}.foreach(println) + env.asScala.filter { case (k, _) => k.contains("SPARK_TEST")}.foreach(println) println("System properties containing spark.test:") - properties.filter{case (k, v) => k.toString.contains("spark.test")}.foreach(println) + properties.filter { case (k, _) => k.toString.contains("spark.test") }.foreach(println) for (i <- 1 until numSecondsToSleep) { println(s"Alive for $i out of $numSecondsToSleep seconds") diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 3ebb112fc069e..805184e740f06 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.examples.pythonconverters import java.util.{Collection => JCollection, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.avro.generic.{GenericFixed, IndexedRecord} import org.apache.avro.mapred.AvroWrapper @@ -58,7 +58,7 @@ object AvroConversionUtil extends Serializable { val map = new java.util.HashMap[String, Any] obj match { case record: IndexedRecord => - record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + record.getSchema.getFields.asScala.zipWithIndex.foreach { case (f, i) => map.put(f.name, fromAvro(record.get(i), f.schema)) } case other => throw new SparkException( @@ -68,9 +68,9 @@ object AvroConversionUtil extends Serializable { } def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { - obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + obj.asInstanceOf[JMap[_, _]].asScala.map { case (key, value) => (key.toString, fromAvro(value, schema.getValueType)) - } + }.asJava } def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { @@ -91,17 +91,17 @@ object AvroConversionUtil extends Serializable { def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { case c: JCollection[_] => - c.map(fromAvro(_, schema.getElementType)) + c.asScala.map(fromAvro(_, schema.getElementType)).toSeq.asJava case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => - arr.toSeq + arr.toSeq.asJava.asInstanceOf[JCollection[Any]] case arr: Array[_] => - arr.map(fromAvro(_, schema.getElementType)).toSeq + arr.map(fromAvro(_, schema.getElementType)).toSeq.asJava case other => throw new SparkException( s"Unknown ARRAY type ${other.getClass.getName}") } def unpackUnion(obj: Any, schema: Schema): Any = { - schema.getTypes.toList match { + schema.getTypes.asScala.toList match { case List(s) => fromAvro(obj, s) case List(n, s) if n.getType == NULL => fromAvro(obj, s) case List(s, n) if n.getType == NULL => fromAvro(obj, s) diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala index 83feb5703b908..00ce47af4813d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -17,11 +17,13 @@ package org.apache.spark.examples.pythonconverters -import org.apache.spark.api.python.Converter import java.nio.ByteBuffer + +import scala.collection.JavaConverters._ + import org.apache.cassandra.utils.ByteBufferUtil -import collection.JavaConversions._ +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra @@ -30,7 +32,7 @@ import collection.JavaConversions._ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] { override def convert(obj: Any): java.util.Map[String, Int] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + result.asScala.mapValues(ByteBufferUtil.toInt).asJava } } @@ -41,7 +43,7 @@ class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int] class CassandraCQLValueConverter 
extends Converter[Any, java.util.Map[String, String]] { override def convert(obj: Any): java.util.Map[String, String] = { val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] - mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + result.asScala.mapValues(ByteBufferUtil.string).asJava } } @@ -52,7 +54,7 @@ class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, St class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, ByteBuffer]] { override def convert(obj: Any): java.util.Map[String, ByteBuffer] = { val input = obj.asInstanceOf[java.util.Map[String, Int]] - mapAsJavaMap(input.mapValues(i => ByteBufferUtil.bytes(i))) + input.asScala.mapValues(ByteBufferUtil.bytes).asJava } } @@ -63,6 +65,6 @@ class ToCassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, By class ToCassandraCQLValueConverter extends Converter[Any, java.util.List[ByteBuffer]] { override def convert(obj: Any): java.util.List[ByteBuffer] = { val input = obj.asInstanceOf[java.util.List[String]] - seqAsJavaList(input.map(s => ByteBufferUtil.bytes(s))) + input.asScala.map(ByteBufferUtil.bytes).asJava } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 90d48a64106c7..0a25ee7ae56f4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -17,7 +17,7 @@ package org.apache.spark.examples.pythonconverters -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter @@ -33,7 +33,6 @@ import org.apache.hadoop.hbase.CellUtil */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { - import collection.JavaConverters._ val result = obj.asInstanceOf[Result] val output = result.listCells.asScala.map(cell => Map( @@ -77,7 +76,7 @@ class StringToImmutableBytesWritableConverter extends Converter[Any, ImmutableBy */ class StringListToPutConverter extends Converter[Any, Put] { override def convert(obj: Any): Put = { - val output = obj.asInstanceOf[java.util.ArrayList[String]].map(Bytes.toBytes(_)).toArray + val output = obj.asInstanceOf[java.util.ArrayList[String]].asScala.map(Bytes.toBytes).toArray val put = new Put(output(0)) put.add(output(1), output(2), output(3)) } diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index fa43629d49771..d2654700ea729 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -20,7 +20,7 @@ import java.net.InetSocketAddress import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -166,7 +166,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("capacity", channelCapacity.toString) channelContext.put("transactionCapacity", 
1000.toString) channelContext.put("keep-alive", 0.toString) - channelContext.putAll(overrides) + channelContext.putAll(overrides.asJava) channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 65c49c131518b..48df27b26867f 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.io.{ObjectOutput, ObjectInput} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.util.Utils import org.apache.spark.Logging @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k, v) <- headers) { + for ((k, v) <- headers.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala index 88cc2aa3bf022..b9d4e762ca05d 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeBatchFetcher.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables @@ -155,7 +154,7 @@ private[flume] class FlumeBatchFetcher(receiver: FlumePollingReceiver) extends R val buffer = new ArrayBuffer[SparkFlumeEvent](events.size()) var j = 0 while (j < events.size()) { - val event = events(j) + val event = events.get(j) val sparkFlumeEvent = new SparkFlumeEvent() sparkFlumeEvent.event.setBody(event.getBody) sparkFlumeEvent.event.setHeaders(event.getHeaders) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 1e32a365a1eee..2bf99cb3cba1f 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -22,7 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.flume.source.avro.AvroSourceProtocol @@ -99,7 +99,7 @@ class SparkFlumeEvent() extends Externalizable { val numHeaders = event.getHeaders.size() out.writeInt(numHeaders) - for ((k, v) <- event.getHeaders) { + for ((k, v) <- event.getHeaders.asScala) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) @@ -127,8 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { - events.foreach (event => - receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) + events.asScala.foreach(event => 
receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 583e7dca317ad..0bc46209b8369 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{LinkedBlockingQueue, Executors} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -94,9 +94,7 @@ private[streaming] class FlumePollingReceiver( override def onStop(): Unit = { logInfo("Shutting down Flume Polling Receiver") receiverExecutor.shutdownNow() - connections.foreach(connection => { - connection.transceiver.close() - }) + connections.asScala.foreach(_.transceiver.close()) channelFactory.releaseExternalResources() } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index 9d9c3b189415f..70018c86f92be 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.flume import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer -import java.util.{List => JList} +import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import org.apache.avro.ipc.NettyTransceiver @@ -59,13 +59,13 @@ private[flume] class FlumeTestUtils { } /** Send data to the flume receiver */ - def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + def writeInput(input: Seq[String], enableCompression: Boolean): Unit = { val testAddress = new InetSocketAddress("localhost", testPort) val inputEvents = input.map { item => val event = new AvroFlumeEvent event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event.setHeaders(Collections.singletonMap("test", "header")) event } @@ -88,7 +88,7 @@ private[flume] class FlumeTestUtils { } // Send data - val status = client.appendBatch(inputEvents.toList) + val status = client.appendBatch(inputEvents.asJava) if (status != avro.Status.OK) { throw new AssertionError("Sent events unsuccessfully") } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index a65a9b921aafa..c719b80aca7ed 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -21,7 +21,7 @@ import java.net.InetSocketAddress import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.api.java.function.PairFunction import org.apache.spark.api.python.PythonRDD 
@@ -268,8 +268,8 @@ private[flume] class FlumeUtilsPythonHelper { maxBatchSize: Int, parallelism: Int ): JavaPairDStream[Array[Byte], Array[Byte]] = { - assert(hosts.length == ports.length) - val addresses = hosts.zip(ports).map { + assert(hosts.size() == ports.size()) + val addresses = hosts.asScala.zip(ports.asScala).map { case (host, port) => new InetSocketAddress(host, port) } val dstream = FlumeUtils.createPollingStream( @@ -286,7 +286,7 @@ private object FlumeUtilsPythonHelper { val output = new DataOutputStream(byteStream) try { output.writeInt(map.size) - map.foreach { kv => + map.asScala.foreach { kv => PythonRDD.writeUTF(kv._1.toString, output) PythonRDD.writeUTF(kv._2.toString, output) } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 91d63d49dbec3..a2ab320957db3 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -18,9 +18,8 @@ package org.apache.spark.streaming.flume import java.util.concurrent._ -import java.util.{List => JList, Map => JMap} +import java.util.{Map => JMap, Collections} -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Charsets.UTF_8 @@ -77,7 +76,7 @@ private[flume] class PollingFlumeTestUtils { /** * Start 2 sinks and return the ports */ - def startMultipleSinks(): JList[Int] = { + def startMultipleSinks(): Seq[Int] = { channels.clear() sinks.clear() @@ -138,8 +137,7 @@ private[flume] class PollingFlumeTestUtils { /** * A Python-friendly method to assert the output */ - def assertOutput( - outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + def assertOutput(outputHeaders: Seq[JMap[String, String]], outputBodies: Seq[String]): Unit = { require(outputHeaders.size == outputBodies.size) val eventSize = outputHeaders.size if (eventSize != totalEventsPerChannel * channels.size) { @@ -149,12 +147,12 @@ private[flume] class PollingFlumeTestUtils { var counter = 0 for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { val eventBodyToVerify = s"${channels(k).getName}-$i" - val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + val eventHeaderToVerify: JMap[String, String] = Collections.singletonMap(s"test-$i", "header") var found = false var j = 0 while (j < eventSize && !found) { - if (eventBodyToVerify == outputBodies.get(j) && - eventHeaderToVerify == outputHeaders.get(j)) { + if (eventBodyToVerify == outputBodies(j) && + eventHeaderToVerify == outputHeaders(j)) { found = true counter += 1 } @@ -195,7 +193,7 @@ private[flume] class PollingFlumeTestUtils { tx.begin() for (j <- 0 until eventsPerBatch) { channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), - Map[String, String](s"test-$t" -> "header"))) + Collections.singletonMap(s"test-$t", "header"))) t += 1 } tx.commit() diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d5f9a0aa38f9f..ff2fb8eed204c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ 
b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -116,9 +116,9 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log // The eventually is required to ensure that all data in the batch has been processed. eventually(timeout(10 seconds), interval(100 milliseconds)) { val flattenOutputBuffer = outputBuffer.flatten - val headers = flattenOutputBuffer.map(_.event.getHeaders.map { - case kv => (kv._1.toString, kv._2.toString) - }).map(mapAsJavaMap) + val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map { + case (key, value) => (key.toString, value.toString) + }).map(_.asJava) val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) utils.assertOutput(headers, bodies) } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 5bc4cdf65306c..5ffb60bd602f9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 79a9db4291bef..c9fd715d3d554 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -24,6 +24,7 @@ import java.util.concurrent.TimeoutException import java.util.{Map => JMap, Properties} import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.language.postfixOps import scala.util.control.NonFatal @@ -159,8 +160,7 @@ private[kafka] class KafkaTestUtils extends Logging { /** Java-friendly function for sending messages to the Kafka broker */ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { - import scala.collection.JavaConversions._ - sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*)) + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) } /** Send the messages to the Kafka broker */ diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 388dbb8184106..3128222077537 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.kafka import java.lang.{Integer => JInt, Long => JLong} import java.util.{List => JList, Map => JMap, Set => JSet} -import scala.collection.JavaConversions._ +import 
scala.collection.JavaConverters._ import scala.reflect.ClassTag import kafka.common.TopicAndPartition @@ -96,7 +96,7 @@ object KafkaUtils { groupId: String, topics: JMap[String, JInt] ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*)) } /** @@ -115,7 +115,7 @@ object KafkaUtils { topics: JMap[String, JInt], storageLevel: StorageLevel ): JavaPairReceiverInputDStream[String, String] = { - createStream(jssc.ssc, zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), + createStream(jssc.ssc, zkQuorum, groupId, Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), storageLevel) } @@ -149,7 +149,10 @@ object KafkaUtils { implicit val valueCmd: ClassTag[T] = ClassTag(valueDecoderClass) createStream[K, V, U, T]( - jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel) + jssc.ssc, + kafkaParams.asScala.toMap, + Map(topics.asScala.mapValues(_.intValue()).toSeq: _*), + storageLevel) } /** get leaders for the given offset ranges, or throw an exception */ @@ -275,7 +278,7 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) new JavaPairRDD(createRDD[K, V, KD, VD]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges)) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges)) } /** @@ -311,9 +314,9 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) - val leaderMap = Map(leaders.toSeq: _*) + val leaderMap = Map(leaders.asScala.toSeq: _*) createRDD[K, V, KD, VD, R]( - jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _) + jsc.sc, Map(kafkaParams.asScala.toSeq: _*), offsetRanges, leaderMap, messageHandler.call(_)) } /** @@ -476,8 +479,8 @@ object KafkaUtils { val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), + Map(kafkaParams.asScala.toSeq: _*), + Map(fromOffsets.asScala.mapValues(_.longValue()).toSeq: _*), cleanedHandler ) } @@ -531,8 +534,8 @@ object KafkaUtils { implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) createDirectStream[K, V, KD, VD]( jssc.ssc, - Map(kafkaParams.toSeq: _*), - Set(topics.toSeq: _*) + Map(kafkaParams.asScala.toSeq: _*), + Set(topics.asScala.toSeq: _*) ) } } @@ -602,10 +605,10 @@ private[kafka] class KafkaUtilsPythonHelper { ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { if (!fromOffsets.isEmpty) { - import scala.collection.JavaConversions._ - val topicsFromOffsets = fromOffsets.keySet().map(_.topic) - if (topicsFromOffsets != topics.toSet) { - throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + val topicsFromOffsets = fromOffsets.keySet().asScala.map(_.topic) + if (topicsFromOffsets != topics.asScala.toSet) { + throw new IllegalStateException( + s"The specified topics: ${topics.asScala.toSet.mkString(" ")} " + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") } } @@ -663,6 +666,6 @@ private[kafka] class KafkaUtilsPythonHelper { "with this RDD, please call this method only on a Kafka 
RDD.") val kafkaRDD = kafkaRDDs.head.asInstanceOf[KafkaRDD[_, _, _, _, _]] - kafkaRDD.offsetRanges.toSeq + kafkaRDD.offsetRanges.toSeq.asJava } } diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 0469d0af8864a..4ea218eaa4de1 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -18,15 +18,17 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString import akka.zeromq.Subscribe + import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream} +import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.receiver.ActorSupervisorStrategy object ZeroMQUtils { @@ -75,7 +77,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +102,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +126,8 @@ object ZeroMQUtils { ): JavaReceiverInputDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator + val fn = + (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).iterator().asScala createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index a003ddf325e6e..5d32fa699ae5b 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.kinesis -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} @@ -213,7 +213,7 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + 
(getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 22324e821ce94..6e0988c1af8a1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kinesis import java.util.UUID -import scala.collection.JavaConversions.asScalaIterator +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal @@ -202,7 +202,7 @@ private[kinesis] class KinesisReceiver( /** Add records of the given shard to the current block being generated */ private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { if (records.size > 0) { - val dataIterator = records.iterator().map { record => + val dataIterator = records.iterator().asScala.map { record => val byteBuffer = record.getData() val byteArray = new Array[Byte](byteBuffer.remaining()) byteBuffer.get(byteArray) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index c8eec13ec7dc7..634bf94521079 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} @@ -115,7 +116,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. 
*/ def pushData(testData: java.util.List[Int]): Unit = { - pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + pushData(testData.asScala) } def deleteStream(): Unit = { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index ceb135e0651aa..3d136aec2e702 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.streaming.kinesis import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Arrays -import scala.collection.JavaConversions.seqAsJavaList - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record @@ -47,10 +47,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val someSeqNum = Some(seqNum) val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) - val batch = List[Record](record1, record2) + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val batch = Arrays.asList(record1, record2) var receiverMock: KinesisReceiver = _ var checkpointerMock: IRecordProcessorCheckpointer = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 87eeb5db05d26..7a1c7796065ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -52,7 +52,7 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLinearInput(intercept, weights, nPoints, seed, eps)) + generateLinearInput(intercept, weights, nPoints, seed, eps).asJava } /** diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index a1ee554152372..2744e020e9e49 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.util.List; -import static scala.collection.JavaConversions.seqAsJavaList; +import scala.collection.JavaConverters; import org.junit.After; import org.junit.Assert; @@ -55,8 +55,9 @@ public void setUp() { double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = 
{0.6856, 0.1899, 3.116, 0.581}; - List points = seqAsJavaList(generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42)); + List points = JavaConverters.asJavaListConverter( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + ).asJava(); datasetRDD = jsc.parallelize(points, 2); dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 2473510e13514..8d14bb6572155 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import scala.util.control.Breaks._ @@ -38,7 +38,7 @@ object LogisticRegressionSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed)) + generateLogisticInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale*X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index b1d78cba9e3dc..ee3c85d09a463 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.jblas.DoubleMatrix @@ -35,7 +35,7 @@ object SVMSuite { weights: Array[Double], nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed)) + generateSVMInput(intercept, weights, nPoints, seed).asJava } // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 13b754a03943a..36ac7d267243d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.optimization -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Random import org.scalatest.Matchers @@ -35,7 +35,7 @@ object GradientDescentSuite { scale: Double, nPoints: Int, seed: Int): java.util.List[LabeledPoint] = { - seqAsJavaList(generateGDInput(offset, scale, nPoints, seed)) + generateGDInput(offset, scale, nPoints, seed).asJava } // Generate input of the form Y = logistic(offset + scale * X) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 05b87728d6fdb..045135f7f8d60 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -17,7 +17,7 @@ 
package org.apache.spark.mllib.recommendation -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.math.abs import scala.util.Random @@ -38,7 +38,7 @@ object ALSSuite { negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { val (sampledRatings, trueRatings, truePrefs) = generateRatings(users, products, features, samplingRate, implicitPrefs) - (seqAsJavaList(sampledRatings), trueRatings, truePrefs) + (sampledRatings.asJava, trueRatings, truePrefs) } def generateRatings( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 04e0d49b178cf..ea52bfd67944a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -18,13 +18,13 @@ import java.io._ import scala.util.Properties -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion -import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings import spray.revolver.RevolverPlugin._ @@ -120,7 +120,7 @@ object SparkBuild extends PomBuild { case _ => } - override val userPropertiesMap = System.getProperties.toMap + override val userPropertiesMap = System.getProperties.asScala.toMap lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -559,7 +559,7 @@ object TestSettings { javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", - javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") + javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark")) .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test += "-ea", javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 8af8637cf948d..0948f9b27cd38 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -61,6 +61,18 @@ def _to_seq(sc, cols, converter=None): return sc._jvm.PythonUtils.toSeq(cols) +def _to_list(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM (Scala) List of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toList(cols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 025811f519293..e269ef4304f3f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -32,7 +32,7 @@ from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql import since from pyspark.sql.types import _parse_datatype_json_string -from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * @@ -494,7 +494,7 @@ def randomSplit(self, weights, seed=None): if w < 0.0: raise ValueError("Weights must be positive. 
Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), long(seed)) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property diff --git a/scalastyle-config.xml b/scalastyle-config.xml index b5e2e882d2254..68fdb4141cf27 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -161,6 +161,13 @@ This file is divided into 3 sections: ]]> + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ec895af9c3037..cfd9cb0e62598 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -280,9 +282,8 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getList[T](i: Int): java.util.List[T] = { - scala.collection.JavaConversions.seqAsJavaList(getSeq[T](i)) - } + def getList[T](i: Int): java.util.List[T] = + getSeq[T](i).asJava /** * Returns the value at position i of map type as a Scala Map. @@ -296,9 +297,8 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getJavaMap[K, V](i: Int): java.util.Map[K, V] = { - scala.collection.JavaConversions.mapAsJavaMap(getMap[K, V](i)) - } + def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + getMap[K, V](i).asJava /** * Returns the value at position i of struct type as an [[Row]] object. 
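For reference, a minimal standalone sketch of the explicit-conversion style the new scalastyle rule above asks for: import scala.collection.JavaConverters._ once, then call .asScala / .asJava at each boundary instead of relying on the JavaConversions implicits. Object and value names here are illustrative only, not part of the patch.

import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._

object ConvertersSketch {
  def main(args: Array[String]): Unit = {
    // Java -> Scala: asScala wraps the Java collection in a Scala view (no copy).
    val jList = new JArrayList[String]()
    jList.add("a")
    jList.add("b")
    val scalaSeq: Seq[String] = jList.asScala.toSeq

    // Scala -> Java: asJava goes the other direction.
    val jMap: java.util.Map[String, Int] = Map("x" -> 1, "y" -> 2).asJava

    println(scalaSeq.mkString(","))
    println(jMap.get("x"))
  }
}

The conversion site is now visible at every call, which is exactly what the style check enforces.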
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 503c4f4b20f38..4cc9a5520a085 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -147,7 +147,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { val result = ArrayBuffer.empty[(String, Boolean)] - for (name <- tables.keySet()) { + for (name <- tables.keySet().asScala) { result += ((name, true)) } result diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index a4fd4cf3b330b..77a42c0873a6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.{lang => jl} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.expressions._ @@ -209,7 +209,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * @since 1.3.1 */ - def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.toSeq) + def fill(valueMap: java.util.Map[String, Any]): DataFrame = fill0(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new [[DataFrame]] that replaces null values. 
@@ -254,7 +254,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = { - replace[T](col, replacement.toMap : Map[T, T]) + replace[T](col, replacement.asScala.toMap) } /** @@ -277,7 +277,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * @since 1.3.1 */ def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = { - replace(cols.toSeq, replacement.toMap) + replace(cols.toSeq, replacement.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 6dc7bfe333498..97a8b6518a832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Experimental @@ -90,7 +92,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameReader = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ce8744b53175b..b2a66dd417b4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.Properties +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -109,7 +111,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def options(options: java.util.Map[String, String]): DataFrameWriter = { - this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this.options(options.asScala) this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 99d557b03a033..ee31d83cce42c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental @@ -188,7 +188,7 @@ class GroupedData protected[sql]( * @since 1.3.0 */ def agg(exprs: java.util.Map[String, String]): DataFrame = { - agg(exprs.toMap) + agg(exprs.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e9de14f025502..e6f7619519e6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.util.Properties import scala.collection.immutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import 
org.apache.parquet.hadoop.ParquetOutputCommitter @@ -531,7 +531,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => setConfString(k, v) } + props.asScala.foreach { case (k, v) => setConfString(k, v) } } /** Set the given Spark SQL configuration property using a `string` value. */ @@ -601,24 +601,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Return all the configuration properties that have been set (i.e. not the default). * This creates a new copy of the config properties in the form of a Map. */ - def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } + def getAllConfs: immutable.Map[String, String] = + settings.synchronized { settings.asScala.toMap } /** * Return all the configuration definitions that have been defined in [[SQLConf]]. Each * definition contains key, defaultValue and doc. */ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { - sqlConfEntries.values.filter(_.isPublic).map { entry => + sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => (entry.key, entry.defaultValueString, entry.doc) }.toSeq } private[spark] def unsetConf(key: String): Unit = { - settings -= key + settings.remove(key) } private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { - settings -= entry.key + settings.remove(entry.key) } private[spark] def clear(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a1eea09e0477b..4e8414af50b44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,7 @@ import java.beans.Introspector import java.util.Properties import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -225,7 +225,7 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.setConf(properties) // After we have populated SQLConf, we call setConf to populate other confs in the subclass // (e.g. hiveconf in HiveContext). 
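The SQLConf hunks above depend on the fact that a java.util.Properties or ConcurrentHashMap only gains Scala collection operations through an explicit .asScala view, and that removal goes through the Java API rather than `-=`. A small sketch of that pattern, with illustrative names standing in for the real SQLConf internals:

import java.util.Properties
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

object ConfSketch {
  private val settings = new ConcurrentHashMap[String, String]()

  def setConf(props: Properties): Unit =
    // Properties.asScala is a mutable Map[String, String] view over the Java object.
    props.asScala.foreach { case (k, v) => settings.put(k, v) }

  // Snapshot as an immutable Scala map; asScala alone only wraps, it does not copy.
  def getAllConfs: Map[String, String] = settings.asScala.toMap

  // The Java map has no `-=`, so removal uses the Java API directly.
  def unsetConf(key: String): Unit = settings.remove(key)
}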
- properties.foreach { + properties.asScala.foreach { case (key, value) => setConf(key, value) } } @@ -567,7 +567,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.toMap) + createExternalTable(tableName, source, options.asScala.toMap) } /** @@ -612,7 +612,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.toMap) + createExternalTable(tableName, source, schema, options.asScala.toMap) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 8fbaf3a3059db..011724436621d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.util.ServiceLoader -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Success, Failure, Try} @@ -55,7 +55,7 @@ object ResolvedDataSource extends Logging { val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - serviceLoader.iterator().filter(_.shortName().equalsIgnoreCase(provider)).toList match { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { /** the provider format did not match any given registered aliases */ case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { case Success(dataSource) => dataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 3f8353af6e2ad..0a6bb44445f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} -import scala.collection.JavaConversions.{iterableAsScalaIterable, mapAsJavaMap, mapAsScalaMap} +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.hadoop.api.ReadSupport.ReadContext @@ -44,7 +44,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with val parquetRequestedSchema = readContext.getRequestedSchema val catalystRequestedSchema = - Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => + Option(readContext.getReadSupportMetadata).map(_.asScala).flatMap { metadata => metadata // First tries to read requested schema, which may result from projections .get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) @@ -123,7 +123,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with maybeRequestedSchema.fold(context.getFileSchema) { schemaString => val toParquet = new CatalystSchemaConverter(conf) val fileSchema = 
context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.map(_.getName).toSet + val fileFieldNames = fileSchema.getFields.asScala.map(_.getName).toSet StructType // Deserializes the Catalyst schema of requested columns @@ -152,7 +152,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - new ReadContext(parquetRequestedSchema, metadata) + new ReadContext(parquetRequestedSchema, metadata.asJava) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index cbf0704c4a9a4..f682ca0d8ff4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary @@ -183,7 +183,7 @@ private[parquet] class CatalystRowConverter( // those missing fields and create converters for them, although values of these fields are // always null. val paddedParquetFields = { - val parquetFields = parquetType.getFields + val parquetFields = parquetType.getFields.asScala val parquetFieldNames = parquetFields.map(_.getName).toSet val missingFields = catalystType.filterNot(f => parquetFieldNames.contains(f.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 535f0684e97f9..be6c0545f5a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.parquet.schema.OriginalType._ @@ -82,7 +82,7 @@ private[parquet] class CatalystSchemaConverter( def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.map { field => + val fields = parquetSchema.getFields.asScala.map { field => field.getRepetition match { case OPTIONAL => StructField(field.getName, convertField(field), nullable = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bbf682aec0f9d..64982f37cf872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -21,7 +21,7 @@ import java.net.URI import java.util.logging.{Logger => 
JLogger} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} @@ -336,7 +336,7 @@ private[sql] class ParquetRelation( override def getPartitions: Array[SparkPartition] = { val inputFormat = new ParquetInputFormat[InternalRow] { override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) } } @@ -344,7 +344,8 @@ private[sql] class ParquetRelation( val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + new SqlNewHadoopPartition( + id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) } } }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] @@ -588,7 +589,7 @@ private[sql] object ParquetRelation extends Logging { val metadata = footer.getParquetMetadata.getFileMetaData val serializedSchema = metadata .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. @@ -745,7 +746,7 @@ private[sql] object ParquetRelation extends Logging { // Reads footers in multi-threaded manner within each task val footers = ParquetFileReader.readAllFootersInParallel( - serializedConf.value, fakeFileStatuses, skipRowGroups) + serializedConf.value, fakeFileStatuses.asJava, skipRowGroups).asScala // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` val converter = @@ -772,7 +773,7 @@ private[sql] object ParquetRelation extends Logging { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData - .toMap + .asScala.toMap .get(CatalystReadSupport.SPARK_METADATA_KEY) .flatMap(deserializeSchemaString) .getOrElse(converter.convert(fileMetaData.getSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 42376ef7a9c1f..142301fe87cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.IOException +import java.util.{Collections, Arrays} -import scala.collection.JavaConversions._ import scala.util.Try import org.apache.hadoop.conf.Configuration @@ -107,7 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetFileWriter.writeMetadataFile( conf, path, - new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + Arrays.asList(new Footer(path, new ParquetMetadata(metaData, Collections.emptyList())))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ed282f98b7d71..d800c7456bdac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ 
-17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -92,9 +92,9 @@ case class ShuffledHashOuterJoin( case FullOuter => // TODO(davies): use UnsafeRow val leftHashTable = - buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)) + buildHashTable(leftIter, numLeftRows, newProjection(leftKeys, left.output)).asScala val rightHashTable = - buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)) + buildHashTable(rightIter, numRightRows, newProjection(rightKeys, right.output)).asScala (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => fullOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 59f8b079ab333..5a58d846ad80b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.io.OutputStream import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import net.razorvine.pickle._ @@ -196,14 +196,15 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) + new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val keys = c.keysIterator.map(fromJava(_, keyType)).toArray - val values = c.valuesIterator.map(fromJava(_, valueType)).toArray + val keyValues = c.asScala.toSeq + val keys = keyValues.map(kv => fromJava(kv._1, keyType)).toArray + val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray ArrayBasedMapData(keys, values) case (c, StructType(fields)) if c.getClass.isArray => @@ -367,7 +368,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val pickle = new Unpickler iter.flatMap { pickedResult => val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala } }.mapPartitions { iter => val row = new GenericMutableRow(1) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 7abdd3db80341..4867cebf5328c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -23,7 +23,7 @@ import java.util.List; import java.util.Map; -import scala.collection.JavaConversions; +import scala.collection.JavaConverters; import scala.collection.Seq; import com.google.common.collect.ImmutableMap; @@ -96,7 +96,7 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); 
df.select(coalesce(col("key"))); - + // Varargs with mathfunctions DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); @@ -172,7 +172,7 @@ public void testCreateDataFrameFromJavaBeans() { Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); + Ints.toArray(JavaConverters.seqAsJavaListConverter(outputBuffer).asJava())); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { @@ -206,7 +206,7 @@ public void testCrosstab() { count++; } } - + @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index cdaa14ac80785..329ffb66083b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.test.SharedSQLContext @@ -153,11 +153,11 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { // Test Java version checkAnswer( - df.na.fill(mapAsJavaMap(Map( + df.na.fill(Map( "a" -> "test", "c" -> 1, "d" -> 2.2 - ))), + ).asJava), Row("test", null, 1, 2.2)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4adcefb7dc4b1..3649c2a97b5ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ @@ -145,7 +145,7 @@ object QueryTest { } def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = { - checkAnswer(df, expectedAnswer.toSeq) match { + checkAnswer(df, expectedAnswer.asScala) match { case Some(errorMessage) => errorMessage case None => null } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index 45db619567a22..bd7cf8c10abef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConverters.seqAsJavaListConverter -import scala.collection.JavaConverters.mapAsJavaMapConverter +import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.generic.IndexedRecord diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 
d85c564e3e8d1..df68432faeeb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.parquet.hadoop.ParquetFileReader @@ -40,8 +40,9 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq override def accept(path: Path): Boolean = pathFilter(path) }).toSeq - val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) - footers.head.getParquetMetadata.getFileMetaData.getSchema + val footers = + ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true) + footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema } protected def logParquetSchema(path: String): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index e6b0a2ea95e38..08d2b9dee99b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConversions._ +import java.util.Collections + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -28,7 +30,7 @@ import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.example.data.{Group, GroupWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, CompressionCodecName, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} @@ -205,9 +207,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("compression codec") { def compressionCodecFor(path: String): String = { val codecs = ParquetTypesConverter - .readMetaData(new Path(path), Some(configuration)) - .getBlocks - .flatMap(_.getColumns) + .readMetaData(new Path(path), Some(configuration)).getBlocks.asScala + .flatMap(_.getColumns.asScala) .map(_.getCodec.name()) .distinct @@ -348,14 +349,16 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Collections.singletonMap( + CatalystReadSupport.SPARK_METADATA_KEY, sparkSchema.toString) val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) ParquetFileWriter.writeMetadataFile( sqlContext.sparkContext.hadoopConfiguration, path, - new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) + 
Collections.singletonList( + new Footer(path, new ParquetMetadata(fileMetadata, Collections.emptyList())))) assertResult(sqlContext.read.parquet(path.toString).schema) { StructType( @@ -386,7 +389,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -410,7 +413,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -434,7 +437,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } } @@ -481,7 +484,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 02cc7e5efa521..306f98bcb5344 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} import java.util.concurrent.RejectedExecutionException -import java.util.{Map => JMap, UUID} +import java.util.{Arrays, Map => JMap, UUID} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} import scala.util.control.NonFatal @@ -126,13 +126,13 @@ private[hive] class SparkExecuteStatementOperation( def getResultSetSchema: TableSchema = { if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") val schema = result.queryExecution.analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new TableSchema(schema) + new TableSchema(schema.asJava) } } @@ -298,7 +298,7 @@ private[hive] class SparkExecuteStatementOperation( sqlOperationConf = new HiveConf(sqlOperationConf) // apply overlay query specific settings, if any - getConfOverlay().foreach { case (k, v) => + getConfOverlay().asScala.foreach { case (k, v) => try { sqlOperationConf.verifyAndSet(k, v) } catch { diff --git 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7799704c819d9..a29df567983b1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.thriftserver -import scala.collection.JavaConversions._ - import java.io._ import java.util.{ArrayList => JArrayList, Locale} +import scala.collection.JavaConverters._ + import jline.console.ConsoleReader import jline.console.history.FileHistory @@ -101,9 +101,9 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf - sessionState.cmdProperties.entrySet().foreach { item => - val key = item.getKey.asInstanceOf[String] - val value = item.getValue.asInstanceOf[String] + sessionState.cmdProperties.entrySet().asScala.foreach { item => + val key = item.getKey.toString + val value = item.getValue.toString // We do not propagate metastore options to the execution copy of hive. if (key != "javax.jdo.option.ConnectionURL") { conf.set(key, value) @@ -316,15 +316,15 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVE_CLI_PRINT_HEADER)) { // Print the column names. - Option(driver.getSchema.getFieldSchemas).map { fields => - out.println(fields.map(_.getName).mkString("\t")) + Option(driver.getSchema.getFieldSchemas).foreach { fields => + out.println(fields.asScala.map(_.getName).mkString("\t")) } } var counter = 0 try { while (!out.checkError() && driver.getResults(res)) { - res.foreach{ l => + res.asScala.foreach { l => counter += 1 out.println(l) } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 644165acf70a7..5ad8c54f296d5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -21,6 +21,8 @@ import java.io.IOException import java.util.{List => JList} import javax.security.auth.login.LoginException +import scala.collection.JavaConverters._ + import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.Utils @@ -34,8 +36,6 @@ import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import scala.collection.JavaConversions._ - private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, hiveContext: HiveContext) extends CLIService(hiveServer) with ReflectedCompositeService { @@ -76,7 +76,7 @@ private[thriftserver] trait ReflectedCompositeService { this: AbstractService => def initCompositeService(hiveConf: HiveConf) { // Emulating `CompositeService.init(hiveConf)` val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList") - serviceList.foreach(_.init(hiveConf)) + serviceList.asScala.foreach(_.init(hiveConf)) // Emulating `AbstractService.init(hiveConf)` invoke(classOf[AbstractService], this, 
"ensureCurrentState", classOf[STATE] -> STATE.NOTINITED) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 77272aecf2835..2619286afc148 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{ArrayList => JArrayList, List => JList} +import java.util.{Arrays, ArrayList => JArrayList, List => JList} + +import scala.collection.JavaConverters._ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} @@ -27,8 +29,6 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import scala.collection.JavaConversions._ - private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) extends Driver @@ -43,14 +43,14 @@ private[hive] class SparkSQLDriver( private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") - if (analyzed.output.size == 0) { - new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) + if (analyzed.output.isEmpty) { + new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null) } else { val fieldSchemas = analyzed.output.map { attr => new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") } - new Schema(fieldSchemas, null) + new Schema(fieldSchemas.asJava, null) } } @@ -79,7 +79,7 @@ private[hive] class SparkSQLDriver( if (hiveResponse == null) { false } else { - res.asInstanceOf[JArrayList[String]].addAll(hiveResponse) + res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava) hiveResponse = null true } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 1d41c46131828..bacf6cc458fd5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.PrintStream -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext @@ -64,7 +64,7 @@ private[hive] object SparkSQLEnv extends Logging { hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.toSeq.sorted.foreach { case (k, v) => + hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 17cc83087fb1d..c0a458fa9ab8d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,7 
+22,7 @@ import java.net.{URL, URLClassLoader} import java.sql.Timestamp import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions import scala.concurrent.duration._ @@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { logInfo("defalt warehouse location is " + defaltWarehouseLocation) // `configure` goes second to override other settings. - val allConfig = metadataConf.iterator.map(e => e.getKey -> e.getValue).toMap ++ configure + val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure val isolatedLoader = if (hiveMetastoreJars == "builtin") { if (hiveExecutionVersion != hiveMetastoreVersion) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 64fffdbf9b020..cfe2bb05ad89e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} @@ -31,9 +33,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * 1. The Underlying data type in catalyst and in Hive * In catalyst: @@ -290,13 +289,13 @@ private[hive] trait HiveInspectors { DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get()) case mi: StandardConstantMapObjectInspector => // take the value from the map inspector object, rather than the input data - val map = mi.getWritableConstantValue - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = mi.getWritableConstantValue.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - val values = li.getWritableConstantValue + val values = li.getWritableConstantValue.asScala .map(unwrap(_, li.getListElementObjectInspector)) .toArray new GenericArrayData(values) @@ -342,7 +341,7 @@ private[hive] trait HiveInspectors { case li: ListObjectInspector => Option(li.getList(data)) .map { l => - val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + val values = l.asScala.map(unwrap(_, li.getListElementObjectInspector)).toArray new GenericArrayData(values) } .orNull @@ -351,15 +350,16 @@ private[hive] trait HiveInspectors { if (map == null) { null } else { - val keys = map.keysIterator.map(unwrap(_, mi.getMapKeyObjectInspector)).toArray - val values = map.valuesIterator.map(unwrap(_, mi.getMapValueObjectInspector)).toArray + val keyValues = map.asScala.toSeq + val keys = keyValues.map(kv => unwrap(kv._1, mi.getMapKeyObjectInspector)).toArray + val values = 
keyValues.map(kv => unwrap(kv._2, mi.getMapValueObjectInspector)).toArray ArrayBasedMapData(keys, values) } // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - InternalRow.fromSeq( - allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) + InternalRow.fromSeq(allRefs.asScala.map( + r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } @@ -403,14 +403,14 @@ private[hive] trait HiveInspectors { case soi: StandardStructObjectInspector => val schema = dataType.asInstanceOf[StructType] - val wrappers = soi.getAllStructFieldRefs.zip(schema.fields).map { case (ref, field) => - wrapperFor(ref.getFieldObjectInspector, field.dataType) + val wrappers = soi.getAllStructFieldRefs.asScala.zip(schema.fields).map { + case (ref, field) => wrapperFor(ref.getFieldObjectInspector, field.dataType) } (o: Any) => { if (o != null) { val struct = soi.create() val row = o.asInstanceOf[InternalRow] - soi.getAllStructFieldRefs.zip(wrappers).zipWithIndex.foreach { + soi.getAllStructFieldRefs.asScala.zip(wrappers).zipWithIndex.foreach { case ((field, wrapper), i) => soi.setStructFieldData(struct, field, wrapper(row.get(i, schema(i).dataType))) } @@ -537,7 +537,7 @@ private[hive] trait HiveInspectors { // 1. create the pojo (most likely) object val result = x.create() var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { // 2. set the property for the pojo val tpe = structType(i).dataType x.setStructFieldData( @@ -552,9 +552,9 @@ private[hive] trait HiveInspectors { val fieldRefs = x.getAllStructFieldRefs val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](fieldRefs.length) + val result = new java.util.ArrayList[AnyRef](fieldRefs.size) var i = 0 - while (i < fieldRefs.length) { + while (i < fieldRefs.size) { val tpe = structType(i).dataType result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 @@ -712,10 +712,10 @@ private[hive] trait HiveInspectors { def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match { case s: StructObjectInspector => - StructType(s.getAllStructFieldRefs.map(f => { + StructType(s.getAllStructFieldRefs.asScala.map(f => types.StructField( f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true) - })) + )) case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector)) case m: MapObjectInspector => MapType( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 98d21aa76d64e..b8da0840ae569 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import com.google.common.base.Objects @@ -483,7 +483,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // are empty. 
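Both the EvaluatePython and HiveInspectors hunks above switch from iterating a Java map's keys and values separately to converting the map once and splitting the pairs, which keeps keys and values aligned by construction. A sketch of that pattern; the unwrap helpers and types here are hypothetical stand-ins, not the real inspector calls:

import scala.collection.JavaConverters._

object MapUnwrapSketch {
  // Hypothetical per-element converters standing in for the real unwrap(...) calls.
  def unwrapKey(k: Any): Any = k
  def unwrapValue(v: Any): Any = v

  def split(jmap: java.util.Map[_, _]): (Array[Any], Array[Any]) = {
    // Materialize the entries once so keys(i) and values(i) come from the same entry.
    val keyValues = jmap.asScala.toSeq
    val keys = keyValues.map(kv => unwrapKey(kv._1)).toArray
    val values = keyValues.map(kv => unwrapValue(kv._2)).toArray
    (keys, values)
  }
}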
val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation - val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map { + val values = InternalRow.fromSeq(p.getValues.asScala.zip(partitionColumnDataTypes).map { case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null) }) ParquetPartition(values, location) @@ -798,9 +798,9 @@ private[hive] case class MetastoreRelation val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) tTable.setPartitionKeys( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.location.foreach(sd.setLocation) table.inputFormat.foreach(sd.setInputFormat) @@ -852,11 +852,11 @@ private[hive] case class MetastoreRelation val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) - tPartition.setValues(p.values) + tPartition.setValues(p.values.asJava) val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tPartition.setSd(sd) - sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) sd.setLocation(p.storage.location) sd.setInputFormat(p.storage.inputFormat) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ad33dee555dd2..d5cd7e98b5267 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -20,6 +20,9 @@ package org.apache.spark.sql.hive import java.sql.Date import java.util.Locale +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.serde.serdeConstants @@ -48,10 +51,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.random.RandomSampler -/* Implicit conversions */ -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer - /** * Used when we need to start parsing the AST before deciding that we are going to pass the command * back for Hive to execute natively. Will be replaced with a native command that contains the @@ -202,7 +201,7 @@ private[hive] object HiveQl extends Logging { * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. */ private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.toSeq).getOrElse(Nil) + Option(s).map(_.asScala).getOrElse(Nil) /** * Returns this ASTNode with the text changed to `newText`. 
@@ -217,7 +216,7 @@ private[hive] object HiveQl extends Logging { */ def withChildren(newChildren: Seq[ASTNode]): ASTNode = { (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren) + n.addChildren(newChildren.asJava) n } @@ -323,11 +322,11 @@ private[hive] object HiveQl extends Logging { assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") val tableOps = tree.getChildren val colList = - tableOps + tableOps.asScala .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") .getOrElse(sys.error("No columnList!")).getChildren - colList.map(nodeToAttribute) + colList.asScala.map(nodeToAttribute) } /** Extractor for matching Hive's AST Tokens. */ @@ -337,7 +336,7 @@ private[hive] object HiveQl extends Logging { case t: ASTNode => CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) Some((t.getText, - Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) + Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) case _ => None } } @@ -424,7 +423,9 @@ private[hive] object HiveQl extends Logging { protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => (None, tableOnly) case Seq(databaseName, table) => (Some(databaseName), table) } @@ -433,7 +434,9 @@ private[hive] object HiveQl extends Logging { } protected def extractTableIdent(tableNameParts: Node): Seq[String] = { - tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -624,7 +627,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list, true) if (cols != null) { tableDesc = tableDesc.copy( - schema = cols.map { field => + schema = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -636,7 +639,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val cols = BaseSemanticAnalyzer.getColumns(list(0), false) if (cols != null) { tableDesc = tableDesc.copy( - partitionColumns = cols.map { field => + partitionColumns = cols.asScala.map { field => HiveColumn(field.getName, field.getType, field.getComment) }) } @@ -672,7 +675,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case _ => assert(false) } tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams) + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) case Token("TOK_TABLELOCATION", child :: Nil) => var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) location = EximUtil.relativeToAbsolutePath(hiveConf, location) @@ -684,7 +687,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeParams = new java.util.HashMap[String, String]() BaseSemanticAnalyzer.readProps( (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy(serdeProperties = 
tableDesc.serdeProperties ++ serdeParams) + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) } case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => child.getText().toLowerCase(Locale.ENGLISH) match { @@ -847,7 +851,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.toSeq + val Seq(whereExpr) = whereNode.getChildren.asScala Filter(nodeToExpr(whereExpr), relations) }.getOrElse(relations) @@ -856,7 +860,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Script transformations are expressed as a select clause with a single expression of type // TOK_TRANSFORM - val transformation = select.getChildren.head match { + val transformation = select.getChildren.iterator().next() match { case Token("TOK_SELEXPR", Token("TOK_TRANSFORM", Token("TOK_EXPLIST", inputExprs) :: @@ -925,10 +929,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withLateralView = lateralViewClause.map { lv => val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.head + Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - val alias = - getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -944,7 +948,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq + select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => @@ -973,7 +977,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Handle HAVING clause. val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.toSeq match { case Seq(hexpr) => nodeToExpr(hexpr) } + val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } // Note that we added a cast to boolean. If the expression itself is already boolean, // the optimizer will get rid of the unnecessary cast. Filter(Cast(havingExpr, BooleanType), withProject) @@ -983,32 +987,42 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val withDistinct = if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - // Handle ORDER BY, SORT BY, DISTRIBETU BY, and CLUSTER BY clause. + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. 
val withSort = (orderByClause, sortByClause, distributeByClause, clusterByClause) match { case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withDistinct) + Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) case (None, Some(perPartitionOrdering), None, None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withDistinct) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), + false, withDistinct) case (None, None, Some(partitionExprs), None) => - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct) + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, + RepartitionByExpression( + partitionExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, Some(clusterExprs)) => - Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withDistinct)) + Sort( + clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), + false, + RepartitionByExpression( + clusterExprs.getChildren.asScala.map(nodeToExpr), + withDistinct)) case (None, None, None, None) => withDistinct case _ => sys.error("Unsupported set of ordering / distribution clauses.") } val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.head)) + limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) .map(Limit(_, withSort)) .getOrElse(withSort) // Collect all window specifications defined in the WINDOW clause. 
- val windowDefinitions = windowClause.map(_.getChildren.toSeq.collect { + val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { case Token("TOK_WINDOWDEF", Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => windowName -> nodesToWindowSpecification(spec) @@ -1063,7 +1077,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText + val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() + .asInstanceOf[ASTNode].getText val (generator, attributes) = nodesToGenerator(clauses) Generate( @@ -1092,7 +1107,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val tableIdent = - tableNameParts.getChildren.map{ case Token(part, Nil) => cleanIdentifier(part)} match { + tableNameParts.getChildren.asScala.map { + case Token(part, Nil) => cleanIdentifier(part) + } match { case Seq(tableOnly) => Seq(tableOnly) case Seq(databaseName, table) => Seq(databaseName, table) case other => sys.error("Hive only supports tables names like 'tableName' " + @@ -1139,7 +1156,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i))) - val joinExpressions = tableOrdinals.map(i => joinArgs(i + 1).getChildren.map(nodeToExpr)) + val joinExpressions = + tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) val joinConditions = joinExpressions.sliding(2).map { case Seq(c1, c2) => @@ -1164,7 +1182,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C joinType = joinType.remove(joinType.length - 1)) } - val groups = (0 until joinExpressions.head.size).map(i => Coalesce(joinExpressions.map(_(i)))) + val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) // Unique join is not really the same as an outer join so we must group together results where // the joinExpressions are the same, taking the First of each value is only okay because the @@ -1229,7 +1247,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1249,7 +1267,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val tableIdent = extractTableIdent(tableNameParts) - val partitionKeys = partitionClause.map(_.getChildren.map { + val partitionKeys = partitionClause.map(_.getChildren.asScala.map { // Parse partitions. We also make keys case insensitive. 
case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) @@ -1590,18 +1608,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = getClauses( Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.toSeq.asInstanceOf[Seq[ASTNode]]) + partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.map(nodeToExpr), - orderByExpr.getChildren.map(nodeToSortOrder)) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), + orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.map(nodeToExpr), Nil) + (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.map(nodeToSortOrder)) + (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.map(nodeToExpr) + val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) (expressions, expressions.map(SortOrder(_, Ascending))) case _ => throw new NotImplementedError( @@ -1639,7 +1657,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.toList match { + frame.getChildren.asScala.toList match { case precedingNode :: followingNode :: Nil => SpecifiedWindowFrame( frameType, @@ -1701,7 +1719,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case other => sys.error(s"Non ASTNode encountered: $other") } - Option(node.getChildren).map(_.toList).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) + Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) builder } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index 267074f3ad102..004805f3aed0b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -22,8 +22,7 @@ import java.rmi.server.UID import org.apache.avro.Schema -/* Implicit conversions */ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -73,7 +72,7 @@ private[hive] object HiveShim { */ def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { if (ids != null && ids.nonEmpty) { - ColumnProjectionUtils.appendReadColumns(conf, ids) + ColumnProjectionUtils.appendReadColumns(conf, ids.asJava) } if (names != null && names.nonEmpty) { appendReadColumnNames(conf, names) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index f49c97de8ff4e..4d1e3ed9198e6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -21,7 +21,7 @@ import java.io.{File, PrintStream} import 
java.util.{Map => JMap} import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path @@ -305,10 +305,11 @@ private[hive] class ClientWrapper( HiveTable( name = h.getTableName, specifiedDatabase = Option(h.getDbName), - schema = h.getCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - partitionColumns = h.getPartCols.map(f => HiveColumn(f.getName, f.getType, f.getComment)), - properties = h.getParameters.toMap, - serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.toMap, + schema = h.getCols.asScala.map(f => HiveColumn(f.getName, f.getType, f.getComment)), + partitionColumns = h.getPartCols.asScala.map(f => + HiveColumn(f.getName, f.getType, f.getComment)), + properties = h.getParameters.asScala.toMap, + serdeProperties = h.getTTable.getSd.getSerdeInfo.getParameters.asScala.toMap, tableType = h.getTableType match { case HTableType.MANAGED_TABLE => ManagedTable case HTableType.EXTERNAL_TABLE => ExternalTable @@ -334,9 +335,9 @@ private[hive] class ClientWrapper( private def toQlTable(table: HiveTable): metadata.Table = { val qlTable = new metadata.Table(table.database, table.name) - qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + qlTable.setFields(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) qlTable.setPartCols( - table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment))) + table.partitionColumns.map(c => new FieldSchema(c.name, c.hiveType, c.comment)).asJava) table.properties.foreach { case (k, v) => qlTable.setProperty(k, v) } table.serdeProperties.foreach { case (k, v) => qlTable.setSerdeParam(k, v) } @@ -366,13 +367,13 @@ private[hive] class ClientWrapper( private def toHivePartition(partition: metadata.Partition): HivePartition = { val apiPartition = partition.getTPartition HivePartition( - values = Option(apiPartition.getValues).map(_.toSeq).getOrElse(Seq.empty), + values = Option(apiPartition.getValues).map(_.asScala).getOrElse(Seq.empty), storage = HiveStorageDescriptor( location = apiPartition.getSd.getLocation, inputFormat = apiPartition.getSd.getInputFormat, outputFormat = apiPartition.getSd.getOutputFormat, serde = apiPartition.getSd.getSerdeInfo.getSerializationLib, - serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.toMap)) + serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } override def getPartitionOption( @@ -397,7 +398,7 @@ private[hive] class ClientWrapper( } override def listTables(dbName: String): Seq[String] = withHiveState { - client.getAllTables(dbName) + client.getAllTables(dbName).asScala } /** @@ -514,17 +515,17 @@ private[hive] class ClientWrapper( } def reset(): Unit = withHiveState { - client.getAllTables("default").foreach { t => + client.getAllTables("default").asScala.foreach { t => logDebug(s"Deleting table $t") val table = client.getTable("default", t) - client.getIndexes("default", t, 255).foreach { index => + client.getIndexes("default", t, 255).asScala.foreach { index => shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) } } - client.getAllDatabases.filterNot(_ == "default").foreach { db => + client.getAllDatabases.asScala.filterNot(_ == "default").foreach { db => logDebug(s"Dropping Database: $db") client.dropDatabase(db, true, false, true) } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 8fc8935b1dc3c..48bbb21e6c1de 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -23,7 +23,7 @@ import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf @@ -201,7 +201,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { setDataLocationMethod.invoke(table, new URI(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq override def getPartitionsByFilter( hive: Hive, @@ -220,7 +220,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[String]() getDriverResultsMethod.invoke(driver, res) - res.toSeq + res.asScala } override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { @@ -310,7 +310,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { setDataLocationMethod.invoke(table, new Path(loc)) override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] = - getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq + getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].asScala.toSeq /** * Converts catalyst expression to the format that Hive's getPartitionsByFilter() expects, i.e. @@ -320,7 +320,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { */ def convertFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
- val varcharKeys = table.getPartitionKeys + val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME)) .map(col => col.getName).toSet @@ -354,7 +354,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]] } - partitions.toSeq + partitions.asScala.toSeq } override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor = @@ -363,7 +363,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { override def getDriverResults(driver: Driver): Seq[String] = { val res = new JArrayList[Object]() getDriverResultsMethod.invoke(driver, res) - res.map { r => + res.asScala.map { r => r match { case s: String => s case a: Array[Object] => a(0).asInstanceOf[String] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 5f0ed5393d191..441b6b6033e1f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema @@ -39,8 +39,8 @@ case class DescribeHiveTableCommand( // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil - val columns: Seq[FieldSchema] = table.hiveQlTable.getCols - val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols + val columns: Seq[FieldSchema] = table.hiveQlTable.getCols.asScala + val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols.asScala results ++= columns.map(field => (field.getName, field.getType, field.getComment)) if (partitionColumns.nonEmpty) { val partColumnInfo = @@ -48,7 +48,7 @@ case class DescribeHiveTableCommand( results ++= partColumnInfo ++ Seq(("# Partition Information", "", "")) ++ - Seq((s"# ${output.get(0).name}", output.get(1).name, output.get(2).name)) ++ + Seq((s"# ${output(0).name}", output(1).name, output(2).name)) ++ partColumnInfo } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ba7eb15a1c0c6..806d2b9b0b7d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition} @@ -98,7 +98,7 @@ case class HiveTableScan( .asInstanceOf[StructObjectInspector] val columnTypeNames = structOI - .getAllStructFieldRefs + .getAllStructFieldRefs.asScala .map(_.getFieldObjectInspector) .map(TypeInfoUtils.getTypeInfoFromObjectInspector(_).getTypeName) .mkString(",") @@ -118,9 +118,8 @@ case class HiveTableScan( case None => partitions case Some(shouldKeep) => partitions.filter { part => val dataTypes = relation.partitionKeys.map(_.dataType) - val castedValues = for ((value, dataType) <- part.getValues.zip(dataTypes)) 
yield { - castFromString(value, dataType) - } + val castedValues = part.getValues.asScala.zip(dataTypes) + .map { case (value, dataType) => castFromString(value, dataType) } // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 62efda613a176..58f7fa640e8a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.hive.execution import java.util +import scala.collection.JavaConverters._ + import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.metastore.MetaStoreUtils import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer @@ -38,8 +39,6 @@ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} - -import scala.collection.JavaConversions._ import org.apache.spark.util.SerializableJobConf private[hive] @@ -94,7 +93,8 @@ case class InsertIntoHiveTable( ObjectInspectorCopyOption.JAVA) .asInstanceOf[StructObjectInspector] - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val fieldOIs = standardOI.getAllStructFieldRefs.asScala + .map(_.getFieldObjectInspector).toArray val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)} val outputData = new Array[Any](fieldOIs.length) @@ -198,7 +198,7 @@ case class InsertIntoHiveTable( // loadPartition call orders directories created on the iteration order of the this map val orderedPartitionSpec = new util.LinkedHashMap[String, String]() - table.hiveQlTable.getPartCols().foreach { entry => + table.hiveQlTable.getPartCols.asScala.foreach { entry => orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } @@ -226,7 +226,7 @@ case class InsertIntoHiveTable( val oldPart = catalog.client.getPartitionOption( catalog.client.getTable(table.databaseName, table.tableName), - partitionSpec) + partitionSpec.asJava) if (oldPart.isEmpty || !ifNotExists) { catalog.client.loadPartition( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index ade27454b9d29..c7651daffe36e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.Properties import javax.annotation.Nullable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants @@ -61,7 +61,7 @@ case class ScriptTransformation( protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = 
List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd) + val builder = new ProcessBuilder(cmd.asJava) val proc = builder.start() val inputStream = proc.getInputStream @@ -172,10 +172,10 @@ case class ScriptTransformation( val fieldList = outputSoi.getAllStructFieldRefs() var i = 0 while (i < dataList.size()) { - if (dataList(i) == null) { + if (dataList.get(i) == null) { mutableRow.setNullAt(i) } else { - mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector) + mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector) } i += 1 } @@ -307,7 +307,7 @@ case class HiveScriptIOSchema ( val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) .asInstanceOf[ObjectInspector] (serde, objectInspector) } @@ -342,7 +342,7 @@ case class HiveScriptIOSchema ( propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) val properties = new Properties() - properties.putAll(propsMap) + properties.putAll(propsMap.asJava) serde.initialize(null, properties) serde diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 7182246e466a4..cad02373e5ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.Try import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} @@ -81,8 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) /* List all of the registered function names. */ override def listFunction(): Seq[String] = { - val a = FunctionRegistry.getFunctionNames ++ underlying.listFunction() - a.toList.sorted + (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted } /* Get the class of the registered function by specified name. 
*/ @@ -116,7 +115,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient private lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava) @transient private lazy val arguments = children.map(toInspector).toArray @@ -541,7 +540,7 @@ private[hive] case class HiveGenericUDTF( @transient protected lazy val collector = new UDTFCollector - lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + lazy val elementTypes = outputInspector.getAllStructFieldRefs.asScala.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 9f4f8b5789afe..1cff5cf9c3543 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Properties +import scala.collection.JavaConverters._ + import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -43,9 +45,6 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -/* Implicit conversions */ -import scala.collection.JavaConversions._ - private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -97,7 +96,8 @@ private[orc] class OrcOutputWriter( private val reusableOutputBuffer = new Array[Any](dataSchema.length) // Used to convert Catalyst values into Hadoop `Writable`s. - private val wrappers = structOI.getAllStructFieldRefs.zip(dataSchema.fields.map(_.dataType)) + private val wrappers = structOI.getAllStructFieldRefs.asScala + .zip(dataSchema.fields.map(_.dataType)) .map { case (ref, dt) => wrapperFor(ref.getFieldObjectInspector, dt) }.toArray diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4da86636ac100..572eaebe81ac2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.test import java.io.File import java.util.{Set => JavaSet} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions @@ -37,9 +38,6 @@ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} -/* Implicit conversions */ -import scala.collection.JavaConversions._ - // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( @@ -405,7 +403,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. 
- org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } @@ -415,9 +413,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => - FunctionRegistry.unregisterTemporaryUDF(udfName) - } + FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). + foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. hiveconf.set("fs.default.name", new File(".").toURI.toString) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 0efcf80bd4ea7..5e7b93d457106 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import scala.collection.JavaConversions._ +import java.util.Collections import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.serde.serdeConstants @@ -38,7 +38,7 @@ class FiltersSuite extends SparkFunSuite with Logging { private val varCharCol = new FieldSchema() varCharCol.setName("varchar") varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME) - testTable.setPartCols(varCharCol :: Nil) + testTable.setPartCols(Collections.singletonList(varCharCol)) filterTest("string filter", (a("stringcol", StringType) > Literal("test")) :: Nil, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index b03a35132325d..9c10ffe1113dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties +import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} @@ -33,8 +32,6 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. 
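The HiveUDFSuite hunks that follow build the test SerDe's object inspectors from explicit java.util.Arrays.asList(...) calls instead of implicitly converted Seqs, and indexed access moves from Scala's row(0) to Java's row.get(0). A small stand-alone sketch of the same idiom, with illustrative names and values:

import java.util.{ArrayList, Arrays, List => JList}

object JavaListSketch {
  def main(args: Array[String]): Unit = {
    // A fixed-size Java list built explicitly rather than via an implicit Seq conversion.
    val fieldNames: JList[String] = Arrays.asList("id", "value")

    // Nested mutable Java lists, read with get(i) instead of Scala's apply(i).
    val row = new ArrayList[ArrayList[AnyRef]]()
    row.add(new ArrayList[AnyRef](2))
    row.get(0).add(Integer.valueOf(1))
    row.get(0).add(Integer.valueOf(2))

    println(fieldNames) // [id, value]
    println(row.get(0)) // [1, 2]
  }
}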
@@ -326,11 +323,11 @@ class PairSerDe extends AbstractSerDe { override def getObjectInspector: ObjectInspector = { ObjectInspectorFactory .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + Arrays.asList("pair"), + Arrays.asList(ObjectInspectorFactory.getStandardStructObjectInspector( + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) )) } @@ -343,10 +340,10 @@ class PairSerDe extends AbstractSerDe { override def deserialize(value: Writable): AnyRef = { val pair = value.asInstanceOf[TestPair] - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) + val row = new ArrayList[ArrayList[AnyRef]] + row.add(new ArrayList[AnyRef](2)) + row.get(0).add(Integer.valueOf(pair.entry._1)) + row.get(0).add(Integer.valueOf(pair.entry._2)) row } @@ -355,9 +352,9 @@ class PairSerDe extends AbstractSerDe { class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector) + Arrays.asList("id", "value"), + Arrays.asList(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) ) override def evaluate(args: Array[DeferredObject]): AnyRef = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 3bf8f3ac20480..210d566745415 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql.hive.test.TestHive -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * A set of test cases that validate partition and column pruning. 
*/ @@ -161,7 +160,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch") assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch") - val actualPartitions = actualPartValues.map(_.toSeq.mkString(",")).sorted + val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted assert(actualPartitions === expectedPartitions, "Partitions selected do not match") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 55ecbd5b5f21d..1ff1d9a2934cc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect @@ -164,7 +164,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { test("show functions") { val allFunctions = (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) checkAnswer(sql("SHOW functions abs"), Row("abs")) checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 5bbca14bad320..7966b43596e75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.sources -import java.sql.Date - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -552,7 +550,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) } } @@ -600,7 +598,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } finally { // Hadoop 1 doesn't have `Configuration.unset` configuration.clear() - clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + clonedConf.asScala.foreach(entry => configuration.set(entry.getKey, entry.getValue)) sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 214cd80108b9b..edfa474677f15 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala 
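The JavaDStreamLike hunks that follow bridge the Java-facing streaming API to the Scala DStream by decorating iterators explicitly, .asJava on the way into the user's Java function and .asScala on the way back, instead of relying on the old asJavaIterator and asScalaIterator implicits. A minimal sketch of that round-trip; the transform function here is a stand-in, not Spark's FlatMapFunction:

import java.util.{ArrayList => JArrayList, Iterator => JIterator}
import scala.collection.JavaConverters._

object IteratorBridgeSketch {
  // A Java-style callback standing in for the user-supplied function of the Java API.
  def javaSideTransform(it: JIterator[Int]): JIterator[Int] = {
    val doubled = new JArrayList[Int]()
    while (it.hasNext) doubled.add(it.next() * 2)
    doubled.iterator()
  }

  def main(args: Array[String]): Unit = {
    val scalaIterator: Iterator[Int] = Iterator(1, 2, 3)
    // Scala -> Java going in, Java -> Scala coming back; both conversions are explicit.
    val result: Iterator[Int] = javaSideTransform(scalaIterator.asJava).asScala
    println(result.toList) // List(2, 4, 6)
  }
}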
@@ -17,11 +17,10 @@ package org.apache.spark.streaming.api.java -import java.util import java.lang.{Long => JLong} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -145,8 +144,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * an array. */ def glom(): JavaDStream[JList[T]] = - new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - + new JavaDStream(dstream.glom().map(_.toSeq.asJava)) /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ @@ -191,7 +189,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { def fn: (Iterator[T]) => Iterator[U] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -204,7 +202,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { - (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + (x: Iterator[T]) => f.call(x.asJava).iterator().asScala } new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } @@ -282,7 +280,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ def slice(fromTime: Time, toTime: Time): JList[R] = { - new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) + dstream.slice(fromTime, toTime).map(wrapRDD).asJava } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 26383e420101e..e2aec6c2f63e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.api.java import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag @@ -116,14 +116,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey().mapValues(asJavaIterable _) + dstream.groupByKey().mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) + dstream.groupByKey(numPartitions).mapValues(_.asJava) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -132,7 +132,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * is used to control the partitioning of each RDD. 
*/ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = - dstream.groupByKey(partitioner).mapValues(asJavaIterable _) + dstream.groupByKey(partitioner).mapValues(_.asJava) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are @@ -197,7 +197,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration).mapValues(_.asJava) } /** @@ -212,7 +212,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(_.asJava) } /** @@ -228,8 +228,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) : JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions).mapValues(_.asJava) } /** @@ -248,8 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( slideDuration: Duration, partitioner: Partitioner ): JavaPairDStream[K, JIterable[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(asJavaIterable _) + dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner).mapValues(_.asJava) } /** @@ -431,7 +429,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values + val list: JList[V] = values.asJava val scalaState: Optional[S] = JavaUtils.optionToOptional(state) val result: Optional[S] = in.apply(list, scalaState) result.isPresent match { @@ -539,7 +537,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -551,8 +549,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( numPartitions: Int ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, numPartitions).mapValues(t => (t._1.asJava, t._2.asJava)) } /** @@ -564,8 +561,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner: Partitioner ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) + dstream.cogroup(other.dstream, partitioner).mapValues(t => (t._1.asJava, t._2.asJava)) } /** diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 35cc3ce5cf468..13f371f29603a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} @@ -115,7 +115,13 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new StreamingContext(master, appName, batchDuration, sparkHome, jars, environment)) + this(new StreamingContext( + master, + appName, + batchDuration, + sparkHome, + jars, + environment.asScala)) /** * Create a JavaStreamingContext using an existing JavaSparkContext. @@ -197,7 +203,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaReceiverInputDStream[T] = { - def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).toIterator + def fn: (InputStream) => Iterator[T] = (x: InputStream) => converter.call(x).iterator().asScala implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -432,7 +438,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue) } @@ -456,7 +462,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime) } @@ -481,7 +487,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] val sQueue = new scala.collection.mutable.Queue[RDD[T]] - sQueue.enqueue(queue.map(_.rdd).toSeq: _*) + sQueue.enqueue(queue.asScala.map(_.rdd).toSeq: _*) ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd) } @@ -500,7 +506,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create a unified DStream from multiple DStreams of the same type and same slide duration. 
*/ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = { - val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[T]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[T] = first.classTag ssc.union(dstreams)(cm) } @@ -512,7 +518,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { first: JavaPairDStream[K, V], rest: JList[JavaPairDStream[K, V]] ): JavaPairDStream[K, V] = { - val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream) + val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ rest.asScala).map(_.dstream) implicit val cm: ClassTag[(K, V)] = first.classTag implicit val kcm: ClassTag[K] = first.kManifest implicit val vcm: ClassTag[V] = first.vManifest @@ -534,12 +540,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ): JavaDStream[T] = { implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** @@ -559,12 +564,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] - val scalaDStreams = dstreams.map(_.dstream).toSeq val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { - val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList + val jrdds = rdds.map(JavaRDD.fromRDD(_)).asJava transformFunc.call(jrdds, time).rdd } - ssc.transform(scalaDStreams, scalaTransformFunc) + ssc.transform(dstreams.asScala.map(_.dstream).toSeq, scalaTransformFunc) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index d06401245ff17..2c373640d2fd9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -20,14 +20,13 @@ package org.apache.spark.streaming.api.python import java.io.{ObjectInputStream, ObjectOutputStream} import java.lang.reflect.Proxy import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConversions._ + import scala.collection.JavaConverters._ import scala.language.existentials import py4j.GatewayServer import org.apache.spark.api.java._ -import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Interval, Duration, Time} @@ -161,7 +160,7 @@ private[python] object PythonDStream { */ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = { val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]] - rdds.forall(queue.add(_)) + rdds.asScala.foreach(queue.add) queue } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala 
index 554aae0117b24..2252e28f22af8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi @@ -144,12 +144,12 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * for being used in the corresponding InputDStream. */ def store(dataIterator: java.util.Iterator[T], metadata: Any) { - supervisor.pushIterator(dataIterator, Some(metadata), None) + supervisor.pushIterator(dataIterator.asScala, Some(metadata), None) } /** Store an iterator of received data as a data block into Spark's memory. */ def store(dataIterator: java.util.Iterator[T]) { - supervisor.pushIterator(dataIterator, None, None) + supervisor.pushIterator(dataIterator.asScala, None, None) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 6d4cdc4aa6b10..0cd39594ee923 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.util.{Failure, Success} import org.apache.spark.Logging @@ -128,7 +128,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def getPendingTimes(): Seq[Time] = { - jobSets.keySet.toSeq + jobSets.asScala.keys.toSeq } def reportError(msg: String, e: Throwable) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 53b96d51c9180..f2711d1355e60 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler import java.nio.ByteBuffer +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions @@ -196,8 +197,7 @@ private[streaming] class ReceivedBlockTracker( writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - import scala.collection.JavaConversions._ - writeAheadLog.readAll().foreach { byteBuffer => + writeAheadLog.readAll().asScala.foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent]( byteBuffer.array, Thread.currentThread().getContextClassLoader) match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index fe6328b1ce727..9f4a4d6806ab5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.postfixOps @@ -118,7 +119,6 @@ private[streaming] class FileBasedWriteAheadLog( * hence the implementation is kept simple. */ def readAll(): JIterator[ByteBuffer] = synchronized { - import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) @@ -126,7 +126,7 @@ private[streaming] class FileBasedWriteAheadLog( logDebug(s"Creating log reader with $file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) - } flatMap { x => x } + }.flatten.asJava } /** diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index bb80bff6dc2e6..57b50bdfd6520 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -17,16 +17,13 @@ package org.apache.spark.streaming -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import java.util.{List => JList} + +import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import java.util.{List => JList} -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming._ -import java.util.ArrayList -import collection.JavaConversions._ import org.apache.spark.api.java.JavaRDDLike -import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext} /** Exposes streaming test functionality in a Java-friendly way. 
*/ trait JavaTestBase extends TestSuiteBase { @@ -39,7 +36,7 @@ trait JavaTestBase extends TestSuiteBase { ssc: JavaStreamingContext, data: JList[JList[T]], numPartitions: Int) = { - val seqData = data.map(Seq(_:_*)) + val seqData = data.asScala.map(_.asScala) implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] @@ -72,9 +69,7 @@ trait JavaTestBase extends TestSuiteBase { implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] ssc.getState() val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[V]]() - res.map(entry => out.append(new ArrayList[V](entry))) - out + res.map(_.asJava).asJava } /** @@ -90,12 +85,7 @@ trait JavaTestBase extends TestSuiteBase { implicit val cm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput) - val out = new ArrayList[JList[JList[V]]]() - res.map{entry => - val lists = entry.map(new ArrayList[V](_)) - out.append(new ArrayList[JList[V]](lists)) - } - out + res.map(entry => entry.map(_.asJava).asJava).asJava } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 325ff7c74c39d..5e49fd00769ad 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -20,6 +20,7 @@ import java.io._ import java.nio.ByteBuffer import java.util +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} @@ -417,9 +418,8 @@ object WriteAheadLogSuite { /** Read all the data in the log file in a directory using the WriteAheadLog class. 
*/ def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - import scala.collection.JavaConversions._ val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - val data = wal.readAll().map(byteBufferToString).toSeq + val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data } diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 9418beb6b3e3a..a0524cabff2d4 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -22,7 +22,7 @@ import java.io.File import java.util.jar.JarFile import scala.collection.mutable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} import scala.util.Try @@ -161,7 +161,7 @@ object GenerateMIMAIgnore { val path = packageName.replace('.', '/') val resources = classLoader.getResources(path) - val jars = resources.filter(x => x.getProtocol == "jar") + val jars = resources.asScala.filter(_.getProtocol == "jar") .map(_.getFile.split(":")(1).split("!")(0)).toSeq jars.flatMap(getClassesFromJar(_, path)) @@ -175,7 +175,7 @@ object GenerateMIMAIgnore { private def getClassesFromJar(jarPath: String, packageName: String) = { import scala.collection.mutable val jar = new JarFile(new File(jarPath)) - val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName)) + val enums = jar.entries().asScala.map(_.getName).filter(_.startsWith(packageName)) val classes = mutable.HashSet[Class[_]]() for (entry <- enums if entry.endsWith(".class")) { try { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bff585b46cbbe..e9a02baafd28e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -25,7 +25,7 @@ import java.security.PrivilegedExceptionAction import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} @@ -511,7 +511,7 @@ private[spark] class Client( val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath YarnSparkHadoopUtil.get.obtainTokensForNamenodes( nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) - val t = creds.getAllTokens + val t = creds.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) .head val newExpiration = t.renew(hadoopConf) @@ -650,8 +650,8 @@ private[spark] class Client( distCacheMgr.setDistArchivesEnv(launchEnv) val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - amContainer.setEnvironment(launchEnv) + amContainer.setLocalResources(localResources.asJava) + amContainer.setEnvironment(launchEnv.asJava) val javaOpts = ListBuffer[String]() @@ -782,7 +782,7 @@ private[spark] class Client( // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList - 
amContainer.setCommands(printableCommands) + amContainer.setCommands(printableCommands.asJava) logDebug("===============================================================================") logDebug("YARN AM launch context:") @@ -797,7 +797,8 @@ private[spark] class Client( // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager).asJava) setupSecurityToken(amContainer) UserGroupInformation.getCurrentUser().addCredentials(credentials) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 4cc50483a17ff..9abd09b3cc7a5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -20,14 +20,13 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI import java.nio.ByteBuffer +import java.util.Collections -import org.apache.hadoop.fs.Path -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation @@ -40,6 +39,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils class ExecutorRunnable( container: Container, @@ -74,9 +74,9 @@ class ExecutorRunnable( .asInstanceOf[ContainerLaunchContext] val localResources = prepareLocalResources - ctx.setLocalResources(localResources) + ctx.setLocalResources(localResources.asJava) - ctx.setEnvironment(env) + ctx.setEnvironment(env.asJava) val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() @@ -96,8 +96,9 @@ class ExecutorRunnable( |=============================================================================== """.stripMargin) - ctx.setCommands(commands) - ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + ctx.setCommands(commands.asJava) + ctx.setApplicationACLs( + YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr).asJava) // If external shuffle service is enabled, register with the Yarn shuffle service already // started on the NodeManager and, if authentication is enabled, provide it with our secret @@ -112,7 +113,7 @@ class ExecutorRunnable( // Authentication is not enabled, so just provide dummy metadata ByteBuffer.allocate(0) } - ctx.setServiceData(Map[String, ByteBuffer]("spark_shuffle" -> secretBytes)) + ctx.setServiceData(Collections.singletonMap("spark_shuffle", secretBytes)) } // Send the start request to the ContainerManager @@ -314,7 +315,8 @@ class ExecutorRunnable( env("SPARK_LOG_URL_STDOUT") = s"$baseUrl/stdout?start=-4096" } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } + 
System.getenv().asScala.filterKeys(_.startsWith("SPARK")) + .foreach { case (k, v) => env(k) = v } env } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index ccf753e69f4b6..5f897cbcb4e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -21,9 +21,7 @@ import java.util.Collections import java.util.concurrent._ import java.util.regex.Pattern -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -39,8 +37,8 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.util.Utils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -164,7 +162,7 @@ private[yarn] class YarnAllocator( * Number of container requests at the given location that have not yet been fulfilled. */ private def getNumPendingAtLocation(location: String): Int = - amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum + amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).asScala.map(_.size).sum /** * Request as many executors from the ResourceManager as needed to reach the desired total. If @@ -231,14 +229,14 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers.asScala) } val completedContainers = allocateResponse.getCompletedContainersStatuses() if (completedContainers.size > 0) { logDebug("Completed %d containers".format(completedContainers.size)) - processCompletedContainers(completedContainers) + processCompletedContainers(completedContainers.asScala) logDebug("Finished processing %d completed containers. Current running executor count: %d." 
.format(completedContainers.size, numExecutorsRunning)) @@ -271,7 +269,7 @@ private[yarn] class YarnAllocator( val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes - val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last + val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.asScala.last logInfo(s"Container request (host: $hostStr, capability: $resource)") } } else if (missing < 0) { @@ -280,7 +278,8 @@ private[yarn] class YarnAllocator( val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource) if (!matchingRequests.isEmpty) { - matchingRequests.head.take(numToCancel).foreach(amClient.removeContainerRequest) + matchingRequests.iterator().next().asScala + .take(numToCancel).foreach(amClient.removeContainerRequest) } else { logWarning("Expected to find pending requests, but found none.") } @@ -459,7 +458,7 @@ private[yarn] class YarnAllocator( } } - if (allocatedContainerToHostMap.containsKey(containerId)) { + if (allocatedContainerToHostMap.contains(containerId)) { val host = allocatedContainerToHostMap.get(containerId).get val containerSet = allocatedHostToContainersMap.get(host).get diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 4999f9c06210a..df042bf291de7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -19,17 +19,15 @@ package org.apache.spark.deploy.yarn import java.util.{List => JList} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.{Map, Set} import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -108,8 +106,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg val method = classOf[WebAppUtils].getMethod("getProxyHostsAndPortsForAmFilter", classOf[Configuration]) val proxies = method.invoke(null, conf).asInstanceOf[JList[String]] - val hosts = proxies.map { proxy => proxy.split(":")(0) } - val uriBases = proxies.map { proxy => prefix + proxy + proxyBase } + val hosts = proxies.asScala.map { proxy => proxy.split(":")(0) } + val uriBases = proxies.asScala.map { proxy => prefix + proxy + proxyBase } Map("PROXY_HOSTS" -> hosts.mkString(","), "PROXY_URI_BASES" -> uriBases.mkString(",")) } catch { case e: NoSuchMethodException => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 128e996b71fe5..b4f8049bff577 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileOutputStream, OutputStreamWriter} import java.util.Properties import java.util.concurrent.TimeUnit -import scala.collection.JavaConversions._ 
+import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -132,7 +132,7 @@ abstract class BaseYarnClusterSuite props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"") - yarnCluster.getConfig().foreach { e => + yarnCluster.getConfig.asScala.foreach { e => props.setProperty("spark.hadoop." + e.getKey(), e.getValue()) } @@ -149,7 +149,7 @@ abstract class BaseYarnClusterSuite props.store(writer, "Spark properties.") writer.close() - val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil + val extraJarArgs = if (extraJars.nonEmpty) Seq("--jars", extraJars.mkString(",")) else Nil val mainArgs = if (klass.endsWith(".py")) { Seq(klass) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 0a5402c89e764..e7f2501e7899f 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap => MutableHashMap} import scala.reflect.ClassTag import scala.util.Try @@ -38,7 +38,7 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { @@ -201,7 +201,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { appContext.getClass.getMethods.filter(_.getName.equals("getApplicationTags")).foreach{ method => val tags = method.invoke(appContext).asInstanceOf[java.util.Set[String]] tags should contain allOf ("tag1", "dup", "tag2", "multi word") - tags.filter(!_.isEmpty).size should be (4) + tags.asScala.filter(_.nonEmpty).size should be (4) } appContext.getMaxAppAttempts should be (42) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 128350b648992..5a4ea2ea2f4ff 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -21,7 +21,6 @@ import java.io.File import java.net.URL import scala.collection.mutable -import scala.collection.JavaConversions._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.{ByteStreams, Files} @@ -216,8 +215,8 @@ private object YarnClusterDriver extends Logging with Matchers { assert(listener.driverLogs.nonEmpty) val driverLogs = listener.driverLogs.get assert(driverLogs.size === 2) - assert(driverLogs.containsKey("stderr")) - assert(driverLogs.containsKey("stdout")) + assert(driverLogs.contains("stderr")) + assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") // Ensure that this is a valid URL, else this will throw an exception new URL(urlStr) From 5c08c86bfa43462fb2ca5f7c5980ddfb44dd57f8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 25 Aug 
2015 10:22:54 -0700 Subject: [PATCH 072/802] [SPARK-10198] [SQL] Turn off partition verification by default Author: Michael Armbrust Closes #8404 from marmbrus/turnOffPartitionVerification. --- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../spark/sql/hive/QueryPartitionSuite.scala | 64 ++++++++++--------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index e6f7619519e6a..9de75f4c4d084 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -312,7 +312,7 @@ private[spark] object SQLConf { doc = "When true, enable filter pushdown for ORC files.") val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", - defaultValue = Some(true), + defaultValue = Some(false), doc = "") val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 017bc2adc103b..1cc8a93e83411 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -18,50 +18,54 @@ package org.apache.spark.sql.hive import com.google.common.io.Files +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{QueryTest, _} import org.apache.spark.util.Utils -class QueryPartitionSuite extends QueryTest { +class QueryPartitionSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.hive.test.TestHive import ctx.implicits._ - import ctx.sql + + protected def _sqlContext = ctx test("SPARK-5068: query data when path doesn't exist"){ - val testData = ctx.sparkContext.parallelize( - (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") + withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { + val testData = ctx.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString))).toDF() + testData.registerTempTable("testData") - val tmpDir = Files.createTempDir() - // create the table for test - sql(s"CREATE TABLE table_with_partition(key int,value string) " + - s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + - "SELECT key,value FROM testData") - sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + - "SELECT key,value FROM testData") + val tmpDir = Files.createTempDir() + // create the table for test + sql(s"CREATE TABLE table_with_partition(key int,value string) " + + s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='1') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='2') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='3') " + + "SELECT key,value FROM testData") + sql("INSERT OVERWRITE TABLE table_with_partition partition (ds='4') " + + "SELECT key,value FROM 
testData") - // test for the exist path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect - ++ testData.toDF.collect ++ testData.toDF.collect) + // test for the exist path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect + ++ testData.toDF.collect ++ testData.toDF.collect) - // delete the path of one partition - tmpDir.listFiles - .find { f => f.isDirectory && f.getName().startsWith("ds=") } - .foreach { f => Utils.deleteRecursively(f) } + // delete the path of one partition + tmpDir.listFiles + .find { f => f.isDirectory && f.getName().startsWith("ds=") } + .foreach { f => Utils.deleteRecursively(f) } - // test for after delete the path - checkAnswer(sql("select key,value from table_with_partition"), - testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) + // test for after delete the path + checkAnswer(sql("select key,value from table_with_partition"), + testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE table_with_partition") + sql("DROP TABLE createAndInsertTest") + } } } From b37f0cc1b4c064d6f09edb161250fa8b783de52a Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 25 Aug 2015 10:54:03 -0700 Subject: [PATCH 073/802] [SPARK-8531] [ML] Update ML user guide for MinMaxScaler jira: https://issues.apache.org/jira/browse/SPARK-8531 Update ML user guide for MinMaxScaler Author: Yuhao Yang Author: unknown Closes #7211 from hhbyyh/minmaxdoc. --- docs/ml-features.md | 71 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 642a4b4c53183..62de4838981cb 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1133,6 +1133,7 @@ val scaledData = scalerModel.transform(dataFrame) {% highlight java %} import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.ml.feature.StandardScalerModel; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.sql.DataFrame; @@ -1173,6 +1174,76 @@ scaledData = scalerModel.transform(dataFrame)
+## MinMaxScaler + +`MinMaxScaler` transforms a dataset of `Vector` rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters: + +* `min`: 0.0 by default. Lower bound after transformation, shared by all features. +* `max`: 1.0 by default. Upper bound after transformation, shared by all features. + +`MinMaxScaler` computes summary statistics on a data set and produces a `MinMaxScalerModel`. The model can then transform each feature individually such that it is in the given range. + +The rescaled value for a feature E is calculated as, +`\begin{equation} + Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min +\end{equation}` +For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)` + +Note that since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input. + +The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1]. + +
+
+More details can be found in the API docs for +[MinMaxScaler](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScaler) and +[MinMaxScalerModel](api/scala/index.html#org.apache.spark.ml.feature.MinMaxScalerModel). +{% highlight scala %} +import org.apache.spark.ml.feature.MinMaxScaler +import org.apache.spark.mllib.util.MLUtils + +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +val dataFrame = sqlContext.createDataFrame(data) +val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + +// Compute summary statistics and generate MinMaxScalerModel +val scalerModel = scaler.fit(dataFrame) + +// rescale each feature to range [min, max]. +val scaledData = scalerModel.transform(dataFrame) +{% endhighlight %} +
+ +
+More details can be found in the API docs for +[MinMaxScaler](api/java/org/apache/spark/ml/feature/MinMaxScaler.html) and +[MinMaxScalerModel](api/java/org/apache/spark/ml/feature/MinMaxScalerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.MinMaxScaler; +import org.apache.spark.ml.feature.MinMaxScalerModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; + +JavaRDD data = + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +MinMaxScaler scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + +// Compute summary statistics and generate MinMaxScalerModel +MinMaxScalerModel scalerModel = scaler.fit(dataFrame); + +// rescale each feature to range [min, max]. +DataFrame scaledData = scalerModel.transform(dataFrame); +{% endhighlight %} +
+
+ ## Bucketizer `Bucketizer` transforms a column of continuous features to a column of feature buckets, where the buckets are specified by users. It takes a parameter: From 881208a8e849facf54166bdd69d3634407f952e7 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 11:58:47 -0700 Subject: [PATCH 074/802] [SPARK-10230] [MLLIB] Rename optimizeAlpha to optimizeDocConcentration See [discussion](https://github.com/apache/spark/pull/8254#discussion_r37837770) CC jkbradley Author: Feynman Liang Closes #8422 from feynmanliang/SPARK-10230. --- .../spark/mllib/clustering/LDAOptimizer.scala | 16 ++++++++-------- .../apache/spark/mllib/clustering/LDASuite.scala | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 5c2aae6403bea..38486e949bbcf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -258,7 +258,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private var tau0: Double = 1024 private var kappa: Double = 0.51 private var miniBatchFraction: Double = 0.05 - private var optimizeAlpha: Boolean = false + private var optimizeDocConcentration: Boolean = false // internal data structure private var docs: RDD[(Long, Vector)] = null @@ -335,20 +335,20 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Optimize alpha, indicates whether alpha (Dirichlet parameter for document-topic distribution) - * will be optimized during training. + * Optimize docConcentration, indicates whether docConcentration (Dirichlet parameter for + * document-topic distribution) will be optimized during training. */ @Since("1.5.0") - def getOptimzeAlpha: Boolean = this.optimizeAlpha + def getOptimizeDocConcentration: Boolean = this.optimizeDocConcentration /** - * Sets whether to optimize alpha parameter during training. + * Sets whether to optimize docConcentration parameter during training. 
* * Default: false */ @Since("1.5.0") - def setOptimzeAlpha(optimizeAlpha: Boolean): this.type = { - this.optimizeAlpha = optimizeAlpha + def setOptimizeDocConcentration(optimizeDocConcentration: Boolean): this.type = { + this.optimizeDocConcentration = optimizeDocConcentration this } @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeAlpha) updateAlpha(gammat) + if (optimizeDocConcentration) updateAlpha(gammat) this } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 8a714f9b79e02..746a76a7e5fa1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -423,7 +423,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val k = 2 val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) - .setGammaShape(100).setOptimzeAlpha(true).setSampleWithReplacement(false) + .setGammaShape(100).setOptimizeDocConcentration(true).setSampleWithReplacement(false) val lda = new LDA().setK(k) .setDocConcentration(1D / k) .setTopicConcentration(0.01) From 16a2be1a84c0a274a60c0a584faaf58b55d4942b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 12:16:23 -0700 Subject: [PATCH 075/802] [SPARK-10231] [MLLIB] update @Since annotation for mllib.classification Update `Since` annotation in `mllib.classification`: 1. add version to classes, objects, constructors, and public variables declared in constructors 2. correct some versions 3. remove `Since` on `toString` MechCoder dbtsai Author: Xiangrui Meng Closes #8421 from mengxr/SPARK-10231 and squashes the following commits: b2dce80 [Xiangrui Meng] update @Since annotation for mllib.classification --- .../classification/ClassificationModel.scala | 7 +++-- .../classification/LogisticRegression.scala | 20 +++++++++---- .../mllib/classification/NaiveBayes.scala | 28 +++++++++++++++---- .../spark/mllib/classification/SVM.scala | 15 ++++++---- .../StreamingLogisticRegressionWithSGD.scala | 9 +++++- 5 files changed, 58 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index a29b425a71fd6..85a413243b049 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. */ @Experimental +@Since("0.8.0") trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. 
@@ -37,7 +38,7 @@ trait ClassificationModel extends Serializable { * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: RDD[Vector]): RDD[Double] /** @@ -46,7 +47,7 @@ trait ClassificationModel extends Serializable { * @param testData array representing a single data point * @return predicted category from the trained model */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: Vector): Double /** @@ -54,7 +55,7 @@ trait ClassificationModel extends Serializable { * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction */ - @Since("0.8.0") + @Since("1.0.0") def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index e03e662227d14..5ceff5b2259ea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD * Multinomial Logistic Regression. By default, it is binary logistic regression * so numClasses will be set to 2. */ -class LogisticRegressionModel ( - override val weights: Vector, - override val intercept: Double, - val numFeatures: Int, - val numClasses: Int) +@Since("0.8.0") +class LogisticRegressionModel @Since("1.3.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("1.0.0") override val intercept: Double, + @Since("1.3.0") val numFeatures: Int, + @Since("1.3.0") val numClasses: Int) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { @@ -75,6 +76,7 @@ class LogisticRegressionModel ( /** * Constructs a [[LogisticRegressionModel]] with weights and intercept for binary classification. */ + @Since("1.0.0") def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2) private var threshold: Option[Double] = Some(0.5) @@ -166,12 +168,12 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" - @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object LogisticRegressionModel extends Loader[LogisticRegressionModel] { @Since("1.3.0") @@ -207,6 +209,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. 
*/ +@Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -216,6 +219,7 @@ class LogisticRegressionWithSGD private[mllib] ( private val gradient = new LogisticGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -227,6 +231,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -238,6 +243,7 @@ class LogisticRegressionWithSGD private[mllib] ( * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ +@Since("0.8.0") object LogisticRegressionWithSGD { // NOTE(shivaram): We use multiple train methods instead of default arguments to support // Java programs. @@ -333,11 +339,13 @@ object LogisticRegressionWithSGD { * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. */ +@Since("1.1.0") class LogisticRegressionWithLBFGS extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { this.setFeatureScaling(true) + @Since("1.1.0") override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) override protected val validators = List(multiLabelValidator) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index dab369207cc9a..a956084ae06e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -41,11 +41,12 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * where D is number of features * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ +@Since("0.9.0") class NaiveBayesModel private[spark] ( - val labels: Array[Double], - val pi: Array[Double], - val theta: Array[Array[Double]], - val modelType: String) + @Since("1.0.0") val labels: Array[Double], + @Since("0.9.0") val pi: Array[Double], + @Since("0.9.0") val theta: Array[Array[Double]], + @Since("1.4.0") val modelType: String) extends ClassificationModel with Serializable with Saveable { import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} @@ -83,6 +84,7 @@ class NaiveBayesModel private[spark] ( throw new UnknownError(s"Invalid modelType: $modelType.") } + @Since("1.0.0") override def predict(testData: RDD[Vector]): RDD[Double] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -91,6 +93,7 @@ class NaiveBayesModel private[spark] ( } } + @Since("1.0.0") override def predict(testData: Vector): Double = { modelType match { case Multinomial => @@ -107,6 +110,7 @@ class NaiveBayesModel private[spark] ( * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities, * in the same order as class labels */ + @Since("1.5.0") def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = { val bcModel = testData.context.broadcast(this) testData.mapPartitions { iter => @@ -122,6 +126,7 @@ class NaiveBayesModel 
private[spark] ( * @return predicted posterior class probabilities from the trained model, * in the same order as class labels */ + @Since("1.5.0") def predictProbabilities(testData: Vector): Vector = { modelType match { case Multinomial => @@ -158,6 +163,7 @@ class NaiveBayesModel private[spark] ( new DenseVector(scaledProbs.map(_ / probSum)) } + @Since("1.3.0") override def save(sc: SparkContext, path: String): Unit = { val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType) NaiveBayesModel.SaveLoadV2_0.save(sc, path, data) @@ -166,6 +172,7 @@ class NaiveBayesModel private[spark] ( override protected def formatVersion: String = "2.0" } +@Since("1.3.0") object NaiveBayesModel extends Loader[NaiveBayesModel] { import org.apache.spark.mllib.util.Loader._ @@ -199,6 +206,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { dataRDD.write.parquet(dataPath(path)) } + @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. @@ -301,30 +309,35 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ - +@Since("0.9.0") class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { import NaiveBayes.{Bernoulli, Multinomial} + @Since("1.4.0") def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) + @Since("0.9.0") def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ + @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { this.lambda = lambda this } /** Get the smoothing parameter. */ + @Since("1.4.0") def getLambda: Double = lambda /** * Set the model type using a string (case-sensitive). * Supported options: "multinomial" (default) and "bernoulli". */ + @Since("1.4.0") def setModelType(modelType: String): NaiveBayes = { require(NaiveBayes.supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") @@ -333,6 +346,7 @@ class NaiveBayes private ( } /** Get the model type. */ + @Since("1.4.0") def getModelType: String = this.modelType /** @@ -340,6 +354,7 @@ class NaiveBayes private ( * * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ + @Since("0.9.0") def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { @@ -423,6 +438,7 @@ class NaiveBayes private ( /** * Top-level methods for calling naive Bayes. */ +@Since("0.9.0") object NaiveBayes { /** String name for multinomial model type. 
*/ @@ -485,7 +501,7 @@ object NaiveBayes { * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli */ - @Since("0.9.0") + @Since("1.4.0") def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), s"NaiveBayes was created with an unknown modelType: $modelType.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 5f87269863572..896565cd90e89 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -33,9 +33,10 @@ import org.apache.spark.rdd.RDD * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -class SVMModel ( - override val weights: Vector, - override val intercept: Double) +@Since("0.8.0") +class SVMModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable with Saveable with PMMLExportable { @@ -47,7 +48,7 @@ class SVMModel ( * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. */ - @Since("1.3.0") + @Since("1.0.0") @Experimental def setThreshold(threshold: Double): this.type = { this.threshold = Some(threshold) @@ -92,12 +93,12 @@ class SVMModel ( override protected def formatVersion: String = "1.0" - @Since("1.4.0") override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } } +@Since("1.3.0") object SVMModel extends Loader[SVMModel] { @Since("1.3.0") @@ -132,6 +133,7 @@ object SVMModel extends Loader[SVMModel] { * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ +@Since("0.8.0") class SVMWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -141,6 +143,7 @@ class SVMWithSGD private ( private val gradient = new HingeGradient() private val updater = new SquaredL2Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -152,6 +155,7 @@ class SVMWithSGD private ( * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, * regParm: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -162,6 +166,7 @@ class SVMWithSGD private ( /** * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. 
*/ +@Since("0.8.0") object SVMWithSGD { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala index 7d33df3221fbf..75630054d1368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.classification -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.StreamingLinearAlgorithm @@ -44,6 +44,7 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm * }}} */ @Experimental +@Since("1.3.0") class StreamingLogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -58,6 +59,7 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.3.0") def this() = this(0.1, 50, 1.0, 0.0) protected val algorithm = new LogisticRegressionWithSGD( @@ -66,30 +68,35 @@ class StreamingLogisticRegressionWithSGD private[mllib] ( protected var model: Option[LogisticRegressionModel] = None /** Set the step size for gradient descent. Default: 0.1. */ + @Since("1.3.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this } /** Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.3.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this } /** Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.3.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this } /** Set the regularization parameter. Default: 0.0. */ + @Since("1.3.0") def setRegParam(regParam: Double): this.type = { this.algorithm.optimizer.setRegParam(regParam) this } /** Set the initial weights. Default: [0.0, 0.0]. */ + @Since("1.3.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this From 71a138cd0e0a14e8426f97877e3b52a562bbd02c Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Tue, 25 Aug 2015 13:14:10 -0700 Subject: [PATCH 076/802] [SPARK-10048] [SPARKR] Support arbitrary nested Java array in serde. This PR: 1. supports transferring arbitrary nested array from JVM to R side in SerDe; 2. based on 1, collect() implemenation is improved. Now it can support collecting data of complex types from a DataFrame. Author: Sun Rui Closes #8276 from sun-rui/SPARK-10048. 
--- R/pkg/R/DataFrame.R | 55 +++++++++--- R/pkg/R/deserialize.R | 72 +++++++--------- R/pkg/R/serialize.R | 10 +-- R/pkg/inst/tests/test_Serde.R | 77 +++++++++++++++++ R/pkg/inst/worker/worker.R | 4 +- .../apache/spark/api/r/RBackendHandler.scala | 7 ++ .../scala/org/apache/spark/api/r/SerDe.scala | 86 +++++++++++-------- .../org/apache/spark/sql/api/r/SQLUtils.scala | 32 +------ 8 files changed, 216 insertions(+), 127 deletions(-) create mode 100644 R/pkg/inst/tests/test_Serde.R diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 10f3c4ea59864..ae1d912cf6da1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -652,18 +652,49 @@ setMethod("dim", setMethod("collect", signature(x = "DataFrame"), function(x, stringsAsFactors = FALSE) { - # listCols is a list of raw vectors, one per column - listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) - cols <- lapply(listCols, function(col) { - objRaw <- rawConnection(col) - numRows <- readInt(objRaw) - col <- readCol(objRaw, numRows) - close(objRaw) - col - }) - names(cols) <- columns(x) - do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors)) - }) + names <- columns(x) + ncol <- length(names) + if (ncol <= 0) { + # empty data.frame with 0 columns and 0 rows + data.frame() + } else { + # listCols is a list of columns + listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) + stopifnot(length(listCols) == ncol) + + # An empty data.frame with 0 columns and number of rows as collected + nrow <- length(listCols[[1]]) + if (nrow <= 0) { + df <- data.frame() + } else { + df <- data.frame(row.names = 1 : nrow) + } + + # Append columns one by one + for (colIndex in 1 : ncol) { + # Note: appending a column of list type into a data.frame so that + # data of complex type can be held. But getting a cell from a column + # of list type returns a list instead of a vector. So for columns of + # non-complex type, append them as vector. + col <- listCols[[colIndex]] + if (length(col) <= 0) { + df[[names[colIndex]]] <- col + } else { + # TODO: more robust check on column of primitive types + vec <- do.call(c, col) + if (class(vec) != "list") { + df[[names[colIndex]]] <- vec + } else { + # For columns of complex type, be careful to access them. + # Get a column of complex type returns a list. + # Get a cell from a column of complex type returns a list instead of a vector. + df[[names[colIndex]]] <- col + } + } + } + df + } + }) #' Limit #' diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 33bf13ec9e784..6cf628e3007de 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -48,6 +48,7 @@ readTypedObject <- function(con, type) { "r" = readRaw(con), "D" = readDate(con), "t" = readTime(con), + "a" = readArray(con), "l" = readList(con), "n" = NULL, "j" = getJobj(readString(con)), @@ -85,8 +86,7 @@ readTime <- function(con) { as.POSIXct(t, origin = "1970-01-01") } -# We only support lists where all elements are of same type -readList <- function(con) { +readArray <- function(con) { type <- readType(con) len <- readInt(con) if (len > 0) { @@ -100,6 +100,25 @@ readList <- function(con) { } } +# Read a list. Types of each element may be different. +# Null objects are read as NA. 
+readList <- function(con) { + len <- readInt(con) + if (len > 0) { + l <- vector("list", len) + for (i in 1:len) { + elem <- readObject(con) + if (is.null(elem)) { + elem <- NA + } + l[[i]] <- elem + } + l + } else { + list() + } +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") @@ -132,18 +151,19 @@ readDeserialize <- function(con) { } } -readDeserializeRows <- function(inputCon) { - # readDeserializeRows will deserialize a DataOutputStream composed of - # a list of lists. Since the DOS is one continuous stream and - # the number of rows varies, we put the readRow function in a while loop - # that termintates when the next row is empty. +readMultipleObjects <- function(inputCon) { + # readMultipleObjects will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. data <- list() while(TRUE) { - row <- readRow(inputCon) - if (length(row) == 0) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { break } - data[[length(data) + 1L]] <- row + data[[length(data) + 1L]] <- readTypedObject(inputCon, type) } data # this is a list of named lists now } @@ -155,35 +175,5 @@ readRowList <- function(obj) { # deserialize the row. rawObj <- rawConnection(obj, "r+") on.exit(close(rawObj)) - readRow(rawObj) -} - -readRow <- function(inputCon) { - numCols <- readInt(inputCon) - if (length(numCols) > 0 && numCols > 0) { - lapply(1:numCols, function(x) { - obj <- readObject(inputCon) - if (is.null(obj)) { - NA - } else { - obj - } - }) # each row is a list now - } else { - list() - } -} - -# Take a single column as Array[Byte] and deserialize it into an atomic vector -readCol <- function(inputCon, numRows) { - if (numRows > 0) { - # sapply can not work with POSIXlt - do.call(c, lapply(1:numRows, function(x) { - value <- readObject(inputCon) - # Replace NULL with NA so we can coerce to vectors - if (is.null(value)) NA else value - })) - } else { - vector() - } + readObject(rawObj) } diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 311021e5d8473..e3676f57f907f 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -110,18 +110,10 @@ writeRowSerialize <- function(outputCon, rows) { serializeRow <- function(row) { rawObj <- rawConnection(raw(0), "wb") on.exit(close(rawObj)) - writeRow(rawObj, row) + writeGenericList(rawObj, row) rawConnectionValue(rawObj) } -writeRow <- function(con, row) { - numCols <- length(row) - writeInt(con, numCols) - for (i in 1:numCols) { - writeObject(con, row[[i]]) - } -} - writeRaw <- function(con, batch) { writeInt(con, length(batch)) writeBin(batch, con, endian = "big") diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R new file mode 100644 index 0000000000000..009db85da2beb --- /dev/null +++ b/R/pkg/inst/tests/test_Serde.R @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("SerDe functionality") + +sc <- sparkR.init() + +test_that("SerDe of primitive types", { + x <- callJStatic("SparkRHandler", "echo", 1L) + expect_equal(x, 1L) + expect_equal(class(x), "integer") + + x <- callJStatic("SparkRHandler", "echo", 1) + expect_equal(x, 1) + expect_equal(class(x), "numeric") + + x <- callJStatic("SparkRHandler", "echo", TRUE) + expect_true(x) + expect_equal(class(x), "logical") + + x <- callJStatic("SparkRHandler", "echo", "abc") + expect_equal(x, "abc") + expect_equal(class(x), "character") +}) + +test_that("SerDe of list of primitive types", { + x <- list(1L, 2L, 3L) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "integer") + + x <- list(1, 2, 3) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "numeric") + + x <- list(TRUE, FALSE) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "logical") + + x <- list("a", "b", "c") + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + expect_equal(class(y[[1]]), "character") + + # Empty list + x <- list() + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) + +test_that("SerDe of list of lists", { + x <- list(list(1L, 2L, 3L), list(1, 2, 3), + list(TRUE, FALSE), list("a", "b", "c")) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) + + # List of empty lists + x <- list(list(), list()) + y <- callJStatic("SparkRHandler", "echo", x) + expect_equal(x, y) +}) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 7e3b5fc403b25..0c3b0d1f4be20 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -94,7 +94,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() @@ -120,7 +120,7 @@ if (isEmpty != 0) { } else if (deserializer == "string") { data <- readLines(inputCon) } else if (deserializer == "row") { - data <- SparkR:::readDeserializeRows(inputCon) + data <- SparkR:::readMultipleObjects(inputCon) } # Timing reading input data for execution inputElap <- elapsedSecs() diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 6ce02e2ea336a..bb82f3285f1d9 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -53,6 +53,13 @@ private[r] class RBackendHandler(server: RBackend) if (objId == "SparkRHandler") { methodName match { + // This function is for test-purpose only + case "echo" => + val args = readArgs(numArgs, dis) + assert(numArgs == 1) + + writeInt(dos, 0) + writeObject(dos, args(0)) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala 
b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index dbbbcf40c1e96..26ad4f1d4697e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -149,6 +149,10 @@ private[spark] object SerDe { case 'b' => readBooleanArr(dis) case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) case 'r' => readBytesArr(dis) + case 'l' => { + val len = readInt(dis) + (0 until len).map(_ => readList(dis)).toArray + } case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") } } @@ -200,6 +204,9 @@ private[spark] object SerDe { case "date" => dos.writeByte('D') case "time" => dos.writeByte('t') case "raw" => dos.writeByte('r') + // Array of primitive types + case "array" => dos.writeByte('a') + // Array of objects case "list" => dos.writeByte('l') case "jobj" => dos.writeByte('j') case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") @@ -211,26 +218,35 @@ private[spark] object SerDe { writeType(dos, "void") } else { value.getClass.getName match { + case "java.lang.Character" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[Character].toString) case "java.lang.String" => writeType(dos, "character") writeString(dos, value.asInstanceOf[String]) - case "long" | "java.lang.Long" => + case "java.lang.Long" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Long].toDouble) - case "float" | "java.lang.Float" => + case "java.lang.Float" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Float].toDouble) - case "decimal" | "java.math.BigDecimal" => + case "java.math.BigDecimal" => writeType(dos, "double") val javaDecimal = value.asInstanceOf[java.math.BigDecimal] writeDouble(dos, scala.math.BigDecimal(javaDecimal).toDouble) - case "double" | "java.lang.Double" => + case "java.lang.Double" => writeType(dos, "double") writeDouble(dos, value.asInstanceOf[Double]) - case "int" | "java.lang.Integer" => + case "java.lang.Byte" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Byte].toInt) + case "java.lang.Short" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Short].toInt) + case "java.lang.Integer" => writeType(dos, "integer") writeInt(dos, value.asInstanceOf[Int]) - case "boolean" | "java.lang.Boolean" => + case "java.lang.Boolean" => writeType(dos, "logical") writeBoolean(dos, value.asInstanceOf[Boolean]) case "java.sql.Date" => @@ -242,43 +258,48 @@ private[spark] object SerDe { case "java.sql.Timestamp" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Timestamp]) + + // Handle arrays + + // Array of primitive types + + // Special handling for byte array case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) - // TODO: Types not handled right now include - // byte, char, short, float - // Handle arrays - case "[Ljava.lang.String;" => - writeType(dos, "list") - writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[C" => + writeType(dos, "array") + writeStringArr(dos, value.asInstanceOf[Array[Char]].map(_.toString)) + case "[S" => + writeType(dos, "array") + writeIntArr(dos, value.asInstanceOf[Array[Short]].map(_.toInt)) case "[I" => - writeType(dos, "list") + writeType(dos, "array") writeIntArr(dos, value.asInstanceOf[Array[Int]]) case "[J" => - writeType(dos, "list") + writeType(dos, "array") writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble)) + case "[F" => + writeType(dos, "array") + writeDoubleArr(dos, 
value.asInstanceOf[Array[Float]].map(_.toDouble)) case "[D" => - writeType(dos, "list") + writeType(dos, "array") writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) case "[Z" => - writeType(dos, "list") + writeType(dos, "array") writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) - case "[[B" => + + // Array of objects, null objects use "void" type + case c if c.startsWith("[") => writeType(dos, "list") - writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) - case otherName => - // Handle array of objects - if (otherName.startsWith("[L")) { - val objArr = value.asInstanceOf[Array[Object]] - writeType(dos, "list") - writeType(dos, "jobj") - dos.writeInt(objArr.length) - objArr.foreach(o => writeJObj(dos, o)) - } else { - writeType(dos, "jobj") - writeJObj(dos, value) - } + val array = value.asInstanceOf[Array[Object]] + writeInt(dos, array.length) + array.foreach(elem => writeObject(dos, elem)) + + case _ => + writeType(dos, "jobj") + writeJObj(dos, value) } } } @@ -350,11 +371,6 @@ private[spark] object SerDe { value.foreach(v => writeString(out, v)) } - def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { - writeType(out, "raw") - out.writeInt(value.length) - value.foreach(v => writeBytes(out, v)) - } } private[r] object SerializationFormats { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 92861ab038f19..7f3defec3d42e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -98,27 +98,17 @@ private[r] object SQLUtils { val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - SerDe.writeInt(dos, row.length) - (0 until row.length).map { idx => - val obj: Object = row(idx).asInstanceOf[Object] - SerDe.writeObject(dos, obj) - } + val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray + SerDe.writeObject(dos, cols) bos.toByteArray() } - def dfToCols(df: DataFrame): Array[Array[Byte]] = { + def dfToCols(df: DataFrame): Array[Array[Any]] = { // localDF is Array[Row] val localDF = df.collect() val numCols = df.columns.length - // dfCols is Array[Array[Any]] - val dfCols = convertRowsToColumns(localDF, numCols) - - dfCols.map { col => - colToRBytes(col) - } - } - def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { + // result is Array[Array[Any]] (0 until numCols).map { colIdx => localDF.map { row => row(colIdx) @@ -126,20 +116,6 @@ private[r] object SQLUtils { }.toArray } - def colToRBytes(col: Array[Any]): Array[Byte] = { - val numRows = col.length - val bos = new ByteArrayOutputStream() - val dos = new DataOutputStream(bos) - - SerDe.writeInt(dos, numRows) - - col.map { item => - val obj: Object = item.asInstanceOf[Object] - SerDe.writeObject(dos, obj) - } - bos.toByteArray() - } - def saveMode(mode: String): SaveMode = { mode match { case "append" => SaveMode.Append From c0e9ff1588b4d9313cc6ec6e00e5c7663eb67910 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 13:21:05 -0700 Subject: [PATCH 077/802] [SPARK-9800] Adds docs for GradientDescent$.runMiniBatchSGD alias * Adds doc for alias of runMIniBatchSGD documenting default value for convergeTol * Cleans up a note in code Author: Feynman Liang Closes #8425 from feynmanliang/SPARK-9800. 
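For orientation, here is a minimal Scala sketch of how the documented alias relates to the full method. It is illustrative only: `sc` is assumed to be an existing SparkContext, the toy data and the LeastSquaresGradient/SimpleUpdater choices are not part of this patch, and the parameter order follows the 1.5 signature as an assumption.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.{GradientDescent, LeastSquaresGradient, SimpleUpdater}

    // Tiny synthetic (label, features) data set, purely for illustration.
    val data = sc.parallelize(Seq(
      (1.0, Vectors.dense(1.0, 0.0)),
      (0.0, Vectors.dense(0.0, 1.0))))
    val initialWeights = Vectors.dense(0.0, 0.0)

    // Full method: convergenceTol passed explicitly.
    val (w1, lossHistory1) = GradientDescent.runMiniBatchSGD(
      data, new LeastSquaresGradient(), new SimpleUpdater(),
      1.0,    // stepSize
      100,    // numIterations
      0.0,    // regParam
      1.0,    // miniBatchFraction
      initialWeights,
      0.001)  // convergenceTol

    // Alias documented by this patch: same call without the last argument;
    // convergenceTol falls back to the default of 0.001.
    val (w2, lossHistory2) = GradientDescent.runMiniBatchSGD(
      data, new LeastSquaresGradient(), new SimpleUpdater(),
      1.0, 100, 0.0, 1.0, initialWeights)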
--- .../apache/spark/mllib/optimization/GradientDescent.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 8f0d1e4aa010a..3b663b5defb03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -235,7 +235,7 @@ object GradientDescent extends Logging { if (miniBatchSize > 0) { /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * lossSum is computed using the weights from the previous iteration * and regVal is the regularization value computed in the previous iteration as well. */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) @@ -264,6 +264,9 @@ object GradientDescent extends Logging { } + /** + * Alias of [[runMiniBatchSGD]] with convergenceTol set to default value of 0.001. + */ def runMiniBatchSGD( data: RDD[(Double, Vector)], gradient: Gradient, From c619c7552f22d28cfa321ce671fc9ca854dd655f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 13:22:38 -0700 Subject: [PATCH 078/802] [SPARK-10237] [MLLIB] update since versions in mllib.fpm Same as #8421 but for `mllib.fpm`. cc feynmanliang Author: Xiangrui Meng Closes #8429 from mengxr/SPARK-10237. --- .../spark/mllib/fpm/AssociationRules.scala | 7 ++++-- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 9 ++++++-- .../apache/spark/mllib/fpm/PrefixSpan.scala | 23 ++++++++++++++++--- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index ba3b447a83398..95c688c86a7e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -82,12 +82,15 @@ class AssociationRules private[fpm] ( }.filter(_.confidence >= minConfidence) } + /** Java-friendly version of [[run]]. */ + @Since("1.5.0") def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { val tag = fakeClassTag[Item] run(freqItemsets.rdd)(tag) } } +@Since("1.5.0") object AssociationRules { /** @@ -104,8 +107,8 @@ object AssociationRules { @Since("1.5.0") @Experimental class Rule[Item] private[fpm] ( - val antecedent: Array[Item], - val consequent: Array[Item], + @Since("1.5.0") val antecedent: Array[Item], + @Since("1.5.0") val consequent: Array[Item], freqUnion: Double, freqAntecedent: Double) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index e37f806271680..aea5c4f8a8a7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -42,7 +42,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") @Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { +class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
* @param confidence minimal confidence of the rules produced @@ -126,6 +127,8 @@ class FPGrowth private ( new FPGrowthModel(freqItemsets) } + /** Java-friendly version of [[run]]. */ + @Since("1.3.0") def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = { implicit val tag = fakeClassTag[Item] run(data.rdd.map(_.asScala.toArray)) @@ -226,7 +229,9 @@ object FPGrowth { * */ @Since("1.3.0") - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + class FreqItemset[Item] @Since("1.3.0") ( + @Since("1.3.0") val items: Array[Item], + @Since("1.3.0") val freq: Long) extends Serializable { /** * Returns items in a Java List. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index dc4ae1d0b69ed..97916daa2e9ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.rdd.RDD @@ -51,6 +51,7 @@ import org.apache.spark.storage.StorageLevel * (Wikipedia)]] */ @Experimental +@Since("1.5.0") class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int, @@ -61,17 +62,20 @@ class PrefixSpan private ( * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}. */ + @Since("1.5.0") def this() = this(0.1, 10, 32000000L) /** * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered * frequent). */ + @Since("1.5.0") def getMinSupport: Double = minSupport /** * Sets the minimal support level (default: `0.1`). */ + @Since("1.5.0") def setMinSupport(minSupport: Double): this.type = { require(minSupport >= 0 && minSupport <= 1, s"The minimum support value must be in [0, 1], but got $minSupport.") @@ -82,11 +86,13 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ + @Since("1.5.0") def getMaxPatternLength: Int = maxPatternLength /** * Sets maximal pattern length (default: `10`). */ + @Since("1.5.0") def setMaxPatternLength(maxPatternLength: Int): this.type = { // TODO: support unbounded pattern length when maxPatternLength = 0 require(maxPatternLength >= 1, @@ -98,12 +104,14 @@ class PrefixSpan private ( /** * Gets the maximum number of items allowed in a projected database before local processing. */ + @Since("1.5.0") def getMaxLocalProjDBSize: Long = maxLocalProjDBSize /** * Sets the maximum number of items (including delimiters used in the internal storage format) * allowed in a projected database before local processing (default: `32000000L`). */ + @Since("1.5.0") def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = { require(maxLocalProjDBSize >= 0L, s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize") @@ -116,6 +124,7 @@ class PrefixSpan private ( * @param data sequences of itemsets. 
* @return a [[PrefixSpanModel]] that contains the frequent patterns */ + @Since("1.5.0") def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = { if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") @@ -202,6 +211,7 @@ class PrefixSpan private ( * @tparam Sequence sequence type, which is an Iterable of Itemsets * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns */ + @Since("1.5.0") def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]]( data: JavaRDD[Sequence]): PrefixSpanModel[Item] = { implicit val tag = fakeClassTag[Item] @@ -211,6 +221,7 @@ class PrefixSpan private ( } @Experimental +@Since("1.5.0") object PrefixSpan extends Logging { /** @@ -535,10 +546,14 @@ object PrefixSpan extends Logging { * @param freq frequency * @tparam Item item type */ - class FreqSequence[Item](val sequence: Array[Array[Item]], val freq: Long) extends Serializable { + @Since("1.5.0") + class FreqSequence[Item] @Since("1.5.0") ( + @Since("1.5.0") val sequence: Array[Array[Item]], + @Since("1.5.0") val freq: Long) extends Serializable { /** * Returns sequence as a Java List of lists for Java users. */ + @Since("1.5.0") def javaSequence: ju.List[ju.List[Item]] = sequence.map(_.toList.asJava).toList.asJava } } @@ -548,5 +563,7 @@ object PrefixSpan extends Logging { * @param freqSequences frequent sequences * @tparam Item item type */ -class PrefixSpanModel[Item](val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) +@Since("1.5.0") +class PrefixSpanModel[Item] @Since("1.5.0") ( + @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) extends Serializable From 9205907876cf65695e56c2a94bedd83df3675c03 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 13:23:15 -0700 Subject: [PATCH 079/802] [SPARK-9797] [MLLIB] [DOC] StreamingLinearRegressionWithSGD.setConvergenceTol default value Adds default convergence tolerance (0.001, set in `GradientDescent.convergenceTol`) to `setConvergenceTol`'s scaladoc Author: Feynman Liang Closes #8424 from feynmanliang/SPARK-9797. --- .../mllib/regression/StreamingLinearRegressionWithSGD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 537a05274eec2..26654e4a06838 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -93,7 +93,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( } /** - * Set the convergence tolerance. + * Set the convergence tolerance. Default: 0.001. */ def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) From 00ae4be97f7b205432db2967ba6d506286ef2ca6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 14:11:38 -0700 Subject: [PATCH 080/802] [SPARK-10239] [SPARK-10244] [MLLIB] update since versions in mllib.pmml and mllib.util Same as #8421 but for `mllib.pmml` and `mllib.util`. 
cc dbtsai Author: Xiangrui Meng Closes #8430 from mengxr/SPARK-10239 and squashes the following commits: a189acf [Xiangrui Meng] update since versions in mllib.pmml and mllib.util --- .../org/apache/spark/mllib/pmml/PMMLExportable.scala | 7 ++++++- .../org/apache/spark/mllib/util/DataValidators.scala | 7 +++++-- .../apache/spark/mllib/util/KMeansDataGenerator.scala | 5 ++++- .../apache/spark/mllib/util/LinearDataGenerator.scala | 10 ++++++++-- .../mllib/util/LogisticRegressionDataGenerator.scala | 5 ++++- .../org/apache/spark/mllib/util/MFDataGenerator.scala | 4 +++- .../scala/org/apache/spark/mllib/util/MLUtils.scala | 2 ++ .../org/apache/spark/mllib/util/SVMDataGenerator.scala | 6 ++++-- .../org/apache/spark/mllib/util/modelSaveLoad.scala | 6 +++++- 9 files changed, 41 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 5e882d4ebb10b..274ac7c99553b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,7 +23,7 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -33,6 +33,7 @@ import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory * developed by the Data Mining Group (www.dmg.org). */ @DeveloperApi +@Since("1.4.0") trait PMMLExportable { /** @@ -48,6 +49,7 @@ trait PMMLExportable { * Export the model to a local file in PMML format */ @Experimental + @Since("1.4.0") def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } @@ -57,6 +59,7 @@ trait PMMLExportable { * Export the model to a directory on a distributed file system in PMML format */ @Experimental + @Since("1.4.0") def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) @@ -67,6 +70,7 @@ trait PMMLExportable { * Export the model to the OutputStream in PMML format */ @Experimental + @Since("1.4.0") def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } @@ -76,6 +80,7 @@ trait PMMLExportable { * Export the model to a String in PMML format */ @Experimental + @Since("1.4.0") def toPMML(): String = { val writer = new StringWriter toPMML(new StreamResult(writer)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index be335a1aca58a..dffe6e78939e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * A collection of methods used to validate data before applying ML algorithms. 
*/ @DeveloperApi +@Since("0.8.0") object DataValidators extends Logging { /** @@ -34,6 +35,7 @@ object DataValidators extends Logging { * * @return True if labels are all zero or one, false otherwise. */ + @Since("1.0.0") val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { @@ -48,6 +50,7 @@ object DataValidators extends Logging { * * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise. */ + @Since("1.3.0") def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index e6bcff48b022c..00fd1606a369c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.rdd.RDD /** @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD * cluster with scale 1 around each center. */ @DeveloperApi +@Since("0.8.0") object KMeansDataGenerator { /** @@ -42,6 +43,7 @@ object KMeansDataGenerator { * @param r Scaling factor for the distribution of the initial centers * @param numPartitions Number of partitions of the generated RDD; default 2 */ + @Since("0.8.0") def generateKMeansRDD( sc: SparkContext, numPoints: Int, @@ -62,6 +64,7 @@ object KMeansDataGenerator { } } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 6) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 7a1c7796065ee..d0ba454f379a9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -22,11 +22,11 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -35,6 +35,7 @@ import org.apache.spark.mllib.regression.LabeledPoint * response variable `Y`. */ @DeveloperApi +@Since("0.8.0") object LinearDataGenerator { /** @@ -46,6 +47,7 @@ object LinearDataGenerator { * @param seed Random seed * @return Java List of input. */ + @Since("0.8.0") def generateLinearInputAsList( intercept: Double, weights: Array[Double], @@ -68,6 +70,7 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. */ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], @@ -92,6 +95,7 @@ object LinearDataGenerator { * @param eps Epsilon scaling factor. * @return Seq of input. 
*/ + @Since("0.8.0") def generateLinearInput( intercept: Double, weights: Array[Double], @@ -132,6 +136,7 @@ object LinearDataGenerator { * * @return RDD of LabeledPoint containing sample data. */ + @Since("0.8.0") def generateLinearRDD( sc: SparkContext, nexamples: Int, @@ -151,6 +156,7 @@ object LinearDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index c09cbe69bb971..33477ee20ebbd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint @@ -31,6 +31,7 @@ import org.apache.spark.mllib.linalg.Vectors * with probability `probOne` and scales features for positive examples by `eps`. */ @DeveloperApi +@Since("0.8.0") object LogisticRegressionDataGenerator { /** @@ -43,6 +44,7 @@ object LogisticRegressionDataGenerator { * @param nparts Number of partitions of the generated RDD. Default value is 2. * @param probOne Probability that a label is 1 (and not 0). Default value is 0.5. */ + @Since("0.8.0") def generateLogisticRDD( sc: SparkContext, nexamples: Int, @@ -62,6 +64,7 @@ object LogisticRegressionDataGenerator { data } + @Since("0.8.0") def main(args: Array[String]) { if (args.length != 5) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 16f430599a515..906bd30563bd0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD @@ -52,7 +52,9 @@ import org.apache.spark.rdd.RDD * testSampFact (Double) Percentage of training data to use as test data. */ @DeveloperApi +@Since("0.8.0") object MFDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4940974bf4f41..81c2f0ce6e12c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -36,6 +36,7 @@ import org.apache.spark.streaming.dstream.DStream /** * Helper methods to load, save and pre-process data used in ML Lib. 
*/ +@Since("0.8.0") object MLUtils { private[mllib] lazy val EPSILON = { @@ -168,6 +169,7 @@ object MLUtils { * * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]] */ + @Since("1.0.0") def saveAsLibSVMFile(data: RDD[LabeledPoint], dir: String) { // TODO: allow to specify label precision and feature precision. val dataStr = data.map { case LabeledPoint(label, features) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index ad20b7694a779..cde5979396178 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -21,11 +21,11 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -33,8 +33,10 @@ import org.apache.spark.mllib.regression.LabeledPoint * for the features and adds Gaussian noise with weight 0.1 to generate labels. */ @DeveloperApi +@Since("0.8.0") object SVMDataGenerator { + @Since("0.8.0") def main(args: Array[String]) { if (args.length < 2) { // scalastyle:off println diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 30d642c754b7c..4d71d534a0774 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -24,7 +24,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -35,6 +35,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * This should be inherited by the class which implements model instances. */ @DeveloperApi +@Since("1.3.0") trait Saveable { /** @@ -50,6 +51,7 @@ trait Saveable { * @param path Path specifying the directory in which to save this model. * If the directory already exists, this method throws an exception. */ + @Since("1.3.0") def save(sc: SparkContext, path: String): Unit /** Current version of model save/load format. */ @@ -64,6 +66,7 @@ trait Saveable { * This should be inherited by an object paired with the model class. */ @DeveloperApi +@Since("1.3.0") trait Loader[M <: Saveable] { /** @@ -75,6 +78,7 @@ trait Loader[M <: Saveable] { * @param path Path specifying the directory to which the model was saved. * @return Model instance */ + @Since("1.3.0") def load(sc: SparkContext, path: String): M } From ec89bd840a6862751999d612f586a962cae63f6d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 14:55:34 -0700 Subject: [PATCH 081/802] [SPARK-10245] [SQL] Fix decimal literals with precision < scale In BigDecimal or java.math.BigDecimal, the precision could be smaller than scale, for example, BigDecimal("0.001") has precision = 1 and scale = 3. 
But DecimalType require that the precision should be larger than scale, so we should use the maximum of precision and scale when inferring the schema from decimal literal. Author: Davies Liu Closes #8428 from davies/smaller_decimal. --- .../spark/sql/catalyst/expressions/literals.scala | 7 ++++--- .../catalyst/expressions/LiteralExpressionSuite.scala | 8 +++++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 34bad23802ba4..8c0c5d5b1e31e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -36,9 +36,10 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) - case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: java.math.BigDecimal => + Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) + case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index f6404d21611e5..015eb1897fb8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -83,12 +83,14 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("decimal") { - List(0.0, 1.2, 1.1111, 5).foreach { d => + List(-0.0001, 0.0, 0.001, 1.2, 1.1111, 5).foreach { d => checkEvaluation(Literal(Decimal(d)), Decimal(d)) checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) - checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), - Decimal((d * 1000L).toLong, 10, 1)) + checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dcb4e83710982..aa07665c6b705 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1627,6 +1627,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(null)) } + test("precision smaller than scale") { + checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) + 
checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) + checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10"))) + checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01"))) + checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001"))) + checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01"))) + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + test("external sorting updates peak execution memory") { withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { val sc = sqlContext.sparkContext From 7467b52ed07f174d93dfc4cb544dc4b69a2c2826 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 25 Aug 2015 15:19:41 -0700 Subject: [PATCH 082/802] [SPARK-10215] [SQL] Fix precision of division (follow the rule in Hive) Follow the rule in Hive for decimal division. see https://github.com/apache/hive/blob/ac755ebe26361a4647d53db2a28500f71697b276/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java#L113 cc chenghao-intel Author: Davies Liu Closes #8415 from davies/decimal_div2. --- .../catalyst/analysis/HiveTypeCoercion.scala | 10 ++++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 9 +++---- .../analysis/DecimalPrecisionSuite.scala | 8 +++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 25 +++++++++++++++++-- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a1aa2a2b2c680..87c11abbad490 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -396,8 +396,14 @@ object HiveTypeCoercion { resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), - max(6, s1 + p2 + 1)) + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1e0cc81dae974..820b336aac759 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends AnalysisTest { - import TestRelations._ + import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { val plan = (1 to 100) @@ -96,7 +95,7 @@ class 
AnalysisSuite extends AnalysisTest { assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 29)) + assert(pl(3).dataType == DecimalType(38, 22)) assert(pl(4).dataType == DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fc11627da6fd1..b4ad618c23e39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,10 +136,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Multiply(i, u), DecimalType(38, 18)) checkType(Multiply(u, u), DecimalType(38, 36)) - checkType(Divide(u, d1), DecimalType(38, 21)) - checkType(Divide(u, d2), DecimalType(38, 24)) - checkType(Divide(u, i), DecimalType(38, 29)) - checkType(Divide(u, u), DecimalType(38, 38)) + checkType(Divide(u, d1), DecimalType(38, 18)) + checkType(Divide(u, d2), DecimalType(38, 19)) + checkType(Divide(u, i), DecimalType(38, 23)) + checkType(Divide(u, u), DecimalType(38, 18)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index aa07665c6b705..9e172b2c264cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1622,9 +1622,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38)))) + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(null)) + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) } test("precision smaller than scale") { From 125205cdb35530cdb4a8fff3e1ee49cf4a299583 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 25 Aug 2015 17:39:20 -0700 Subject: [PATCH 083/802] [SPARK-9888] [MLLIB] User guide for new LDA features * Adds two new sections to LDA's user guide; one for each optimizer/model * Documents new features added to LDA (e.g. topXXXperXXX, asymmetric priors, hyperpam optimization) * Cleans up a TODO and sets a default parameter in LDA code jkbradley hhbyyh Author: Feynman Liang Closes #8254 from feynmanliang/SPARK-9888. 
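As a companion to the guide text in the diff below, a minimal Scala sketch of choosing between the two LDA optimizers. It assumes an existing SparkContext `sc`; the toy corpus and parameter values are illustrative only.

    import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    // Corpus of (document id, term-count vector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 6.0)),
      (1L, Vectors.dense(1.0, 3.0, 0.0)),
      (2L, Vectors.dense(1.0, 4.0, 1.0))))

    // Expectation-maximization: produces a DistributedLDAModel.
    val emModel = new LDA()
      .setK(2)
      .setOptimizer(new EMLDAOptimizer)
      .setMaxIterations(10)
      .run(corpus)

    // Online variational Bayes: produces a LocalLDAModel.
    val onlineModel = new LDA()
      .setK(2)
      .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(0.5))
      .setMaxIterations(10)
      .run(corpus)

    // Both models expose the inferred topics.
    val topics = onlineModel.describeTopics(maxTermsPerTopic = 2)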
--- docs/mllib-clustering.md | 135 +++++++++++++++--- .../spark/mllib/clustering/LDAModel.scala | 1 - .../spark/mllib/clustering/LDASuite.scala | 1 + 3 files changed, 117 insertions(+), 20 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index fd9ab258e196d..3fb35d3c50b06 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -438,28 +438,125 @@ sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") is a topic model which infers topics from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: -* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset. -* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts. -* Rather than estimating a clustering using a traditional distance, LDA uses a function based - on a statistical model of how text documents are generated. - -LDA takes in a collection of documents as vectors of word counts. -It supports different inference algorithms via `setOptimizer` function. EMLDAOptimizer learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) -on the likelihood function and yields comprehensive results, while OnlineLDAOptimizer uses iterative mini-batch sampling for [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) and is generally memory friendly. After fitting on the documents, LDA provides: - -* Topics: Inferred topics, each of which is a probability distribution over terms (words). -* Topic distributions for documents: For each non empty document in the training set, LDA gives a probability distribution over topics. (EM only). Note that for empty documents, we don't create the topic distributions. (EM only) +* Topics correspond to cluster centers, and documents correspond to +examples (rows) in a dataset. +* Topics and documents both exist in a feature space, where feature +vectors are vectors of word counts (bag of words). +* Rather than estimating a clustering using a traditional distance, LDA +uses a function based on a statistical model of how text documents are +generated. + +LDA supports different inference algorithms via `setOptimizer` function. +`EMLDAOptimizer` learns clustering using +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +on the likelihood function and yields comprehensive results, while +`OnlineLDAOptimizer` uses iterative mini-batch sampling for [online +variational +inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf) +and is generally memory friendly. -LDA takes the following parameters: +LDA takes in a collection of documents as vectors of word counts and the +following parameters (set using the builder pattern): * `k`: Number of topics (i.e., cluster centers) -* `maxIterations`: Limit on the number of iterations of EM used for learning -* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions. -* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions. -* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. 
If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery. - -*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet -support prediction on new documents, and it does not have a Python API. These will be added in the future. +* `optimizer`: Optimizer to use for learning the LDA model, either +`EMLDAOptimizer` or `OnlineLDAOptimizer` +* `docConcentration`: Dirichlet parameter for prior over documents' +distributions over topics. Larger values encourage smoother inferred +distributions. +* `topicConcentration`: Dirichlet parameter for prior over topics' +distributions over terms (words). Larger values encourage smoother +inferred distributions. +* `maxIterations`: Limit on the number of iterations. +* `checkpointInterval`: If using checkpointing (set in the Spark +configuration), this parameter specifies the frequency with which +checkpoints will be created. If `maxIterations` is large, using +checkpointing can help reduce shuffle file sizes on disk and help with +failure recovery. + + +All of MLlib's LDA models support: + +* `describeTopics`: Returns topics as arrays of most important terms and +term weights +* `topicsMatrix`: Returns a `vocabSize` by `k` matrix where each column +is a topic + +*Note*: LDA is still an experimental feature under active development. +As a result, certain features are only available in one of the two +optimizers / models generated by the optimizer. Currently, a distributed +model can be converted into a local model, but not vice-versa. + +The following discussion will describe each optimizer/model pair +separately. + +**Expectation Maximization** + +Implemented in +[`EMLDAOptimizer`](api/scala/index.html#org.apache.spark.mllib.clustering.EMLDAOptimizer) +and +[`DistributedLDAModel`](api/scala/index.html#org.apache.spark.mllib.clustering.DistributedLDAModel). + +For the parameters provided to `LDA`: + +* `docConcentration`: Only symmetric priors are supported, so all values +in the provided `k`-dimensional vector must be identical. All values +must also be $> 1.0$. Providing `Vector(-1)` results in default behavior +(uniform `k` dimensional vector with value $(50 / k) + 1$ +* `topicConcentration`: Only symmetric priors supported. Values must be +$> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. +* `maxIterations`: The maximum number of EM iterations. + +`EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only +the inferred topics but also the full training corpus and topic +distributions for each document in the training corpus. A +`DistributedLDAModel` supports: + + * `topTopicsPerDocument`: The top topics and their weights for + each document in the training corpus + * `topDocumentsPerTopic`: The top documents for each topic and + the corresponding weight of the topic in the documents. + * `logPrior`: log probability of the estimated topics and + document-topic distributions given the hyperparameters + `docConcentration` and `topicConcentration` + * `logLikelihood`: log likelihood of the training corpus, given the + inferred topics and document-topic distributions + +**Online Variational Bayes** + +Implemented in +[`OnlineLDAOptimizer`](api/scala/org/apache/spark/mllib/clustering/OnlineLDAOptimizer.html) +and +[`LocalLDAModel`](api/scala/org/apache/spark/mllib/clustering/LocalLDAModel.html). 
+ +For the parameters provided to `LDA`: + +* `docConcentration`: Asymmetric priors can be used by passing in a +vector with values equal to the Dirichlet parameter in each of the `k` +dimensions. Values should be $>= 0$. Providing `Vector(-1)` results in +default behavior (uniform `k` dimensional vector with value $(1.0 / k)$) +* `topicConcentration`: Only symmetric priors supported. Values must be +$>= 0$. Providing `-1` results in defaulting to a value of $(1.0 / k)$. +* `maxIterations`: Maximum number of minibatches to submit. + +In addition, `OnlineLDAOptimizer` accepts the following parameters: + +* `miniBatchFraction`: Fraction of corpus sampled and used at each +iteration +* `optimizeDocConcentration`: If set to true, performs maximum-likelihood +estimation of the hyperparameter `docConcentration` (aka `alpha`) +after each minibatch and sets the optimized `docConcentration` in the +returned `LocalLDAModel` +* `tau0` and `kappa`: Used for learning-rate decay, which is computed by +$(\tau_0 + iter)^{-\kappa}$ where $iter$ is the current number of iterations. + +`OnlineLDAOptimizer` produces a `LocalLDAModel`, which only stores the +inferred topics. A `LocalLDAModel` supports: + +* `logLikelihood(documents)`: Calculates a lower bound on the provided +`documents` given the inferred topics. +* `logPerplexity(documents)`: Calculates an upper bound on the +perplexity of the provided `documents` given the inferred topics. **Examples** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 667374a2bc418..432bbedc8d6f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -435,7 +435,6 @@ object LocalLDAModel extends Loader[LocalLDAModel] { } val topicsMat = Matrices.fromBreeze(brzTopics) - // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 746a76a7e5fa1..37fb69d68f6be 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,6 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Train a model val lda = new LDA() lda.setK(k) + .setOptimizer(new EMLDAOptimizer) .setDocConcentration(topicSmoothing) .setTopicConcentration(termSmoothing) .setMaxIterations(5) From 8668ead2e7097b9591069599fbfccf67c53db659 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 18:17:54 -0700 Subject: [PATCH 084/802] [SPARK-10233] [MLLIB] update since version in mllib.evaluation Same as #8421 but for `mllib.evaluation`. cc avulanov Author: Xiangrui Meng Closes #8423 from mengxr/SPARK-10233. 
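The diff below annotates the public members of the evaluation metrics classes; for orientation, a minimal Scala sketch of one of those classes. It assumes an existing SparkContext `sc`; the prediction/label pairs are illustrative only.

    import org.apache.spark.mllib.evaluation.MulticlassMetrics

    // (prediction, label) pairs, toy values for illustration.
    val predictionAndLabels = sc.parallelize(Seq(
      (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0)))

    val metrics = new MulticlassMetrics(predictionAndLabels)
    metrics.precision        // overall precision
    metrics.recall(1.0)      // recall for label 1.0
    metrics.confusionMatrix  // k x k confusion matrix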
--- .../evaluation/BinaryClassificationMetrics.scala | 8 ++++---- .../spark/mllib/evaluation/MulticlassMetrics.scala | 11 ++++++++++- .../spark/mllib/evaluation/MultilabelMetrics.scala | 12 +++++++++++- .../spark/mllib/evaluation/RegressionMetrics.scala | 3 ++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index 76ae847921f44..508fe532b1306 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -42,11 +42,11 @@ import org.apache.spark.sql.DataFrame * be smaller as a result, meaning there may be an extra sample at * partition boundaries. */ -@Since("1.3.0") +@Since("1.0.0") @Experimental -class BinaryClassificationMetrics( - val scoreAndLabels: RDD[(Double, Double)], - val numBins: Int) extends Logging { +class BinaryClassificationMetrics @Since("1.3.0") ( + @Since("1.3.0") val scoreAndLabels: RDD[(Double, Double)], + @Since("1.3.0") val numBins: Int) extends Logging { require(numBins >= 0, "numBins must be nonnegative") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 02e89d921033c..00e837661dfc2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.DataFrame */ @Since("1.1.0") @Experimental -class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { +class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Double)]) { /** * An auxiliary constructor taking a DataFrame. 
@@ -140,6 +140,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns precision */ + @Since("1.1.0") lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount /** @@ -148,23 +149,27 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * because sum of all false positives is equal to sum * of all false negatives) */ + @Since("1.1.0") lazy val recall: Double = precision /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ + @Since("1.1.0") lazy val fMeasure: Double = precision /** * Returns weighted true positive rate * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedTruePositiveRate: Double = weightedRecall /** * Returns weighted false positive rate */ + @Since("1.1.0") lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => falsePositiveRate(category) * count.toDouble / labelCount }.sum @@ -173,6 +178,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { * Returns weighted averaged recall * (equals to precision, recall and f-measure) */ + @Since("1.1.0") lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => recall(category) * count.toDouble / labelCount }.sum @@ -180,6 +186,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged precision */ + @Since("1.1.0") lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => precision(category) * count.toDouble / labelCount }.sum @@ -196,6 +203,7 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns weighted averaged f1-measure */ + @Since("1.1.0") lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => fMeasure(category, 1.0) * count.toDouble / labelCount }.sum @@ -203,5 +211,6 @@ class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { /** * Returns the sequence of labels in ascending order */ + @Since("1.1.0") lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index a0a8d9c56847b..c100b3c9ec14a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.DataFrame * both are non-null Arrays, each with unique elements. */ @Since("1.2.0") -class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { +class MultilabelMetrics @Since("1.2.0") (predictionAndLabels: RDD[(Array[Double], Array[Double])]) { /** * An auxiliary constructor taking a DataFrame. 
@@ -46,6 +46,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns subset accuracy * (for equal sets of labels) */ + @Since("1.2.0") lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => predictions.deep == labels.deep }.count().toDouble / numDocs @@ -53,6 +54,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns accuracy */ + @Since("1.2.0") lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs @@ -61,6 +63,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns Hamming-loss */ + @Since("1.2.0") lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => labels.size + predictions.size - 2 * labels.intersect(predictions).size }.sum / (numDocs * numLabels) @@ -68,6 +71,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based precision averaged by the number of documents */ + @Since("1.2.0") lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => if (predictions.size > 0) { predictions.intersect(labels).size.toDouble / predictions.size @@ -79,6 +83,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based recall averaged by the number of documents */ + @Since("1.2.0") lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => labels.intersect(predictions).size.toDouble / labels.size }.sum / numDocs @@ -86,6 +91,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] /** * Returns document-based f1-measure averaged by the number of documents */ + @Since("1.2.0") lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) }.sum / numDocs @@ -143,6 +149,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based precision * (equals to micro-averaged document-based precision) */ + @Since("1.2.0") lazy val microPrecision: Double = { val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} sumTp.toDouble / (sumTp + sumFp) @@ -152,6 +159,7 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based recall * (equals to micro-averaged document-based recall) */ + @Since("1.2.0") lazy val microRecall: Double = { val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} sumTp.toDouble / (sumTp + sumFn) @@ -161,10 +169,12 @@ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])] * Returns micro-averaged label-based f1-measure * (equals to micro-averaged document-based f1-measure) */ + @Since("1.2.0") lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) /** * Returns the sequence of labels in ascending order */ + @Since("1.2.0") lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 36a6c357c3897..799ebb980ef01 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.DataFrame */ @Since("1.2.0") @Experimental -class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("1.2.0") ( + predictionAndObservations: RDD[(Double, Double)]) extends Logging { /** * An auxiliary constructor taking a DataFrame. From ab431f8a970b85fba34ccb506c0f8815e55c63bf Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 20:07:56 -0700 Subject: [PATCH 085/802] [SPARK-10238] [MLLIB] update since versions in mllib.linalg Same as #8421 but for `mllib.linalg`. cc dbtsai Author: Xiangrui Meng Closes #8440 from mengxr/SPARK-10238 and squashes the following commits: b38437e [Xiangrui Meng] update since versions in mllib.linalg --- .../apache/spark/mllib/linalg/Matrices.scala | 44 ++++++++++++------- .../linalg/SingularValueDecomposition.scala | 1 + .../apache/spark/mllib/linalg/Vectors.scala | 25 ++++++++--- .../linalg/distributed/BlockMatrix.scala | 10 +++-- .../linalg/distributed/CoordinateMatrix.scala | 4 +- .../distributed/DistributedMatrix.scala | 2 + .../linalg/distributed/IndexedRowMatrix.scala | 4 +- .../mllib/linalg/distributed/RowMatrix.scala | 5 ++- 8 files changed, 64 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 28b5b4637bf17..c02ba426fcc3a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -32,18 +32,23 @@ import org.apache.spark.sql.types._ * Trait for a local matrix. */ @SQLUserDefinedType(udt = classOf[MatrixUDT]) +@Since("1.0.0") sealed trait Matrix extends Serializable { /** Number of rows. */ + @Since("1.0.0") def numRows: Int /** Number of columns. */ + @Since("1.0.0") def numCols: Int /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + @Since("1.3.0") val isTransposed: Boolean = false /** Converts to a dense array in column major. */ + @Since("1.0.0") def toArray: Array[Double] = { val newArray = new Array[Double](numRows * numCols) foreachActive { (i, j, v) => @@ -56,6 +61,7 @@ sealed trait Matrix extends Serializable { private[mllib] def toBreeze: BM[Double] /** Gets the (i, j)-th element. */ + @Since("1.3.0") def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ @@ -65,12 +71,15 @@ sealed trait Matrix extends Serializable { private[mllib] def update(i: Int, j: Int, v: Double): Unit /** Get a deep copy of the matrix. */ + @Since("1.2.0") def copy: Matrix /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + @Since("1.3.0") def transpose: Matrix /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ + @Since("1.2.0") def multiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) BLAS.gemm(1.0, this, y, 0.0, C) @@ -78,11 +87,13 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ + @Since("1.2.0") def multiply(y: DenseVector): DenseVector = { multiply(y.asInstanceOf[Vector]) } /** Convenience method for `Matrix`-`Vector` multiplication. 
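As a quick illustration of the local Matrix operations tagged above, a sketch with arbitrary values:

    import org.apache.spark.mllib.linalg.{DenseVector, Matrices}

    // 2x2 identity, entries given in column-major order
    val m = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
    val v = new DenseVector(Array(3.0, 4.0))
    val mv = m.multiply(v)  // equals v here, since m is the identity
    val mt = m.transpose    // new Matrix sharing the same underlying data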
*/ + @Since("1.4.0") def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) @@ -93,6 +104,7 @@ sealed trait Matrix extends Serializable { override def toString: String = toBreeze.toString() /** A human readable representation of the matrix with maximum lines and width */ + @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) /** Map the values of this matrix using a function. Generates a new matrix. Performs the @@ -118,11 +130,13 @@ sealed trait Matrix extends Serializable { /** * Find the number of non-zero active values. */ + @Since("1.5.0") def numNonzeros: Int /** * Find the number of values stored explicitly. These values can be zero as well. */ + @Since("1.5.0") def numActives: Int } @@ -230,11 +244,11 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class DenseMatrix( - val numRows: Int, - val numCols: Int, - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class DenseMatrix @Since("1.3.0") ( + @Since("1.0.0") val numRows: Int, + @Since("1.0.0") val numCols: Int, + @Since("1.0.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") @@ -254,7 +268,7 @@ class DenseMatrix( * @param numCols number of columns * @param values matrix entries in column major */ - @Since("1.3.0") + @Since("1.0.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) @@ -491,13 +505,13 @@ object DenseMatrix { */ @Since("1.2.0") @SQLUserDefinedType(udt = classOf[MatrixUDT]) -class SparseMatrix( - val numRows: Int, - val numCols: Int, - val colPtrs: Array[Int], - val rowIndices: Array[Int], - val values: Array[Double], - override val isTransposed: Boolean) extends Matrix { +class SparseMatrix @Since("1.3.0") ( + @Since("1.2.0") val numRows: Int, + @Since("1.2.0") val numCols: Int, + @Since("1.2.0") val colPtrs: Array[Int], + @Since("1.2.0") val rowIndices: Array[Int], + @Since("1.2.0") val values: Array[Double], + @Since("1.3.0") override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") @@ -527,7 +541,7 @@ class SparseMatrix( * order for each column * @param values non-zero matrix entries in column major */ - @Since("1.3.0") + @Since("1.2.0") def this( numRows: Int, numCols: Int, @@ -549,8 +563,6 @@ class SparseMatrix( } } - /** - */ @Since("1.3.0") override def apply(i: Int, j: Int): Double = { val ind = index(i, j) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index a37aca99d5e72..4dcf8f28f2023 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -31,6 +31,7 @@ case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VTyp * :: Experimental :: * Represents QR factors. 
*/ +@Since("1.5.0") @Experimental case class QRDecomposition[QType, RType](Q: QType, R: RType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 3d577edbe23e1..06ebb15869909 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -38,16 +38,19 @@ import org.apache.spark.sql.types._ * Note: Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) +@Since("1.0.0") sealed trait Vector extends Serializable { /** * Size of the vector. */ + @Since("1.0.0") def size: Int /** * Converts the instance to a double array. */ + @Since("1.0.0") def toArray: Array[Double] override def equals(other: Any): Boolean = { @@ -99,11 +102,13 @@ sealed trait Vector extends Serializable { * Gets the value of the ith element. * @param i index */ + @Since("1.1.0") def apply(i: Int): Double = toBreeze(i) /** * Makes a deep copy of this vector. */ + @Since("1.1.0") def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } @@ -121,26 +126,31 @@ sealed trait Vector extends Serializable { * Number of active entries. An "active entry" is an element which is explicitly stored, * regardless of its value. Note that inactive entries have value 0. */ + @Since("1.4.0") def numActives: Int /** * Number of nonzero elements. This scans all active values and count nonzeros. */ + @Since("1.4.0") def numNonzeros: Int /** * Converts this vector to a sparse vector with all explicit zeros removed. */ + @Since("1.4.0") def toSparse: SparseVector /** * Converts this vector to a dense vector. */ + @Since("1.4.0") def toDense: DenseVector = new DenseVector(this.toArray) /** * Returns a vector in either dense or sparse format, whichever uses less storage. */ + @Since("1.4.0") def compressed: Vector = { val nnz = numNonzeros // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. @@ -155,6 +165,7 @@ sealed trait Vector extends Serializable { * Find the index of a maximal element. Returns the first maximal element in case of a tie. * Returns -1 if vector has length 0. */ + @Since("1.5.0") def argmax: Int } @@ -532,7 +543,8 @@ object Vectors { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class DenseVector(val values: Array[Double]) extends Vector { +class DenseVector @Since("1.0.0") ( + @Since("1.0.0") val values: Array[Double]) extends Vector { @Since("1.0.0") override def size: Int = values.length @@ -632,7 +644,9 @@ class DenseVector(val values: Array[Double]) extends Vector { @Since("1.3.0") object DenseVector { + /** Extracts the value array from a dense vector. */ + @Since("1.3.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } @@ -645,10 +659,10 @@ object DenseVector { */ @Since("1.0.0") @SQLUserDefinedType(udt = classOf[VectorUDT]) -class SparseVector( - override val size: Int, - val indices: Array[Int], - val values: Array[Double]) extends Vector { +class SparseVector @Since("1.0.0") ( + @Since("1.0.0") override val size: Int, + @Since("1.0.0") val indices: Array[Int], + @Since("1.0.0") val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + @@ -819,6 +833,7 @@ class SparseVector( @Since("1.3.0") object SparseVector { + @Since("1.3.0") def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 94376c24a7ac6..a33b6137cf9cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -131,10 +131,10 @@ private[mllib] object GridPartitioner { */ @Since("1.3.0") @Experimental -class BlockMatrix( - val blocks: RDD[((Int, Int), Matrix)], - val rowsPerBlock: Int, - val colsPerBlock: Int, +class BlockMatrix @Since("1.3.0") ( + @Since("1.3.0") val blocks: RDD[((Int, Int), Matrix)], + @Since("1.3.0") val rowsPerBlock: Int, + @Since("1.3.0") val colsPerBlock: Int, private var nRows: Long, private var nCols: Long) extends DistributedMatrix with Logging { @@ -171,7 +171,9 @@ class BlockMatrix( nCols } + @Since("1.3.0") val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + @Since("1.3.0") val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt private[mllib] def createPartitioner(): GridPartitioner = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 4bb27ec840902..644f293d88a75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -46,8 +46,8 @@ case class MatrixEntry(i: Long, j: Long, value: Double) */ @Since("1.0.0") @Experimental -class CoordinateMatrix( - val entries: RDD[MatrixEntry], +class CoordinateMatrix @Since("1.0.0") ( + @Since("1.0.0") val entries: RDD[MatrixEntry], private var nRows: Long, private var nCols: Long) extends DistributedMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala index e51327ebb7b58..db3433a5e2456 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -28,9 +28,11 @@ import org.apache.spark.annotation.Since trait DistributedMatrix extends Serializable { /** Gets or computes the number of rows. */ + @Since("1.0.0") def numRows(): Long /** Gets or computes the number of columns. */ + @Since("1.0.0") def numCols(): Long /** Collects data and assembles a local dense breeze matrix (for test only). 
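A short sketch of the Vector factory and conversion methods tagged in this patch (sizes and values are arbitrary):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    val dv: Vector = Vectors.dense(1.0, 0.0, 3.0)
    val sv: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))
    println(dv.numNonzeros)  // 2
    println(dv.compressed)   // whichever of dense/sparse uses less storage
    println(sv.toDense)      // explicit dense copy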
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 6d2c05a47d049..b20ea0dc50da5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -45,8 +45,8 @@ case class IndexedRow(index: Long, vector: Vector) */ @Since("1.0.0") @Experimental -class IndexedRowMatrix( - val rows: RDD[IndexedRow], +class IndexedRowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[IndexedRow], private var nRows: Long, private var nCols: Int) extends DistributedMatrix { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 78036eba5c3e6..9a423ddafdc09 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -47,8 +47,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.0.0") @Experimental -class RowMatrix( - val rows: RDD[Vector], +class RowMatrix @Since("1.0.0") ( + @Since("1.0.0") val rows: RDD[Vector], private var nRows: Long, private var nCols: Int) extends DistributedMatrix with Logging { @@ -519,6 +519,7 @@ class RowMatrix( * @param computeQ whether to computeQ * @return QRDecomposition(Q, R), Q = null if computeQ = false. */ + @Since("1.5.0") def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them From c3a54843c0c8a14059da4e6716c1ad45c69bbe6c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:31:23 -0700 Subject: [PATCH 086/802] [SPARK-10240] [SPARK-10242] [MLLIB] update since versions in mlilb.random and mllib.stat The same as #8241 but for `mllib.stat` and `mllib.random`. cc feynmanliang Author: Xiangrui Meng Closes #8439 from mengxr/SPARK-10242. --- .../mllib/random/RandomDataGenerator.scala | 43 ++++++++++-- .../spark/mllib/random/RandomRDDs.scala | 69 ++++++++++++++++--- .../distribution/MultivariateGaussian.scala | 6 +- .../spark/mllib/stat/test/TestResult.scala | 24 ++++--- 4 files changed, 117 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9349ecaa13f56..a2d85a68cd327 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution.{ExponentialDistribution, GammaDistribution, LogNormalDistribution, PoissonDistribution} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** @@ -28,17 +28,20 @@ import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} * Trait for random data generators that generate i.i.d. data. */ @DeveloperApi +@Since("1.1.0") trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** * Returns an i.i.d. sample as a generic type from an underlying distribution. 
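A sketch of the distributed RowMatrix API whose constructor and tallSkinnyQR method are tagged above (toy rows; assumes an active SparkContext `sc`):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.distributed.RowMatrix

    val rows = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), Vectors.dense(5.0, 6.0)))
    val mat = new RowMatrix(rows)
    val qr = mat.tallSkinnyQR(computeQ = true)  // QRDecomposition(Q: RowMatrix, R: Matrix)
    println(qr.R)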
*/ + @Since("1.1.0") def nextValue(): T /** * Returns a copy of the RandomDataGenerator with a new instance of the rng object used in the * class when applicable for non-locking concurrent usage. */ + @Since("1.1.0") def copy(): RandomDataGenerator[T] } @@ -47,17 +50,21 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { * Generates i.i.d. samples from U[0.0, 1.0] */ @DeveloperApi +@Since("1.1.0") class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextDouble() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): UniformGenerator = new UniformGenerator() } @@ -66,17 +73,21 @@ class UniformGenerator extends RandomDataGenerator[Double] { * Generates i.i.d. samples from the standard normal distribution. */ @DeveloperApi +@Since("1.1.0") class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. private val random = new XORShiftRandom() + @Since("1.1.0") override def nextValue(): Double = { random.nextGaussian() } + @Since("1.1.0") override def setSeed(seed: Long): Unit = random.setSeed(seed) + @Since("1.1.0") override def copy(): StandardNormalGenerator = new StandardNormalGenerator() } @@ -87,16 +98,21 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { * @param mean mean for the Poisson distribution. */ @DeveloperApi -class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.1.0") +class PoissonGenerator @Since("1.1.0") ( + @Since("1.1.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new PoissonDistribution(mean) + @Since("1.1.0") override def nextValue(): Double = rng.sample() + @Since("1.1.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.1.0") override def copy(): PoissonGenerator = new PoissonGenerator(mean) } @@ -107,16 +123,21 @@ class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { * @param mean mean for the exponential distribution. 
*/ @DeveloperApi -class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class ExponentialGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double) extends RandomDataGenerator[Double] { private val rng = new ExponentialDistribution(mean) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): ExponentialGenerator = new ExponentialGenerator(mean) } @@ -128,16 +149,22 @@ class ExponentialGenerator(val mean: Double) extends RandomDataGenerator[Double] * @param scale scale for the gamma distribution */ @DeveloperApi -class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class GammaGenerator @Since("1.3.0") ( + @Since("1.3.0") val shape: Double, + @Since("1.3.0") val scale: Double) extends RandomDataGenerator[Double] { private val rng = new GammaDistribution(shape, scale) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): GammaGenerator = new GammaGenerator(shape, scale) } @@ -150,15 +177,21 @@ class GammaGenerator(val shape: Double, val scale: Double) extends RandomDataGen * @param std standard deviation for the log normal distribution */ @DeveloperApi -class LogNormalGenerator(val mean: Double, val std: Double) extends RandomDataGenerator[Double] { +@Since("1.3.0") +class LogNormalGenerator @Since("1.3.0") ( + @Since("1.3.0") val mean: Double, + @Since("1.3.0") val std: Double) extends RandomDataGenerator[Double] { private val rng = new LogNormalDistribution(mean, std) + @Since("1.3.0") override def nextValue(): Double = rng.sample() + @Since("1.3.0") override def setSeed(seed: Long) { rng.reseedRandomGenerator(seed) } + @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 174d5e0f6c9f0..4dd5ea214d678 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.random import scala.reflect.ClassTag import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} @@ -32,6 +32,7 @@ import org.apache.spark.util.Utils * Generator methods for creating RDDs comprised of `i.i.d.` samples from some distribution. */ @Experimental +@Since("1.1.0") object RandomRDDs { /** @@ -46,6 +47,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformRDD( sc: SparkContext, size: Long, @@ -58,6 +60,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformRDD]]. */ + @Since("1.1.0") def uniformJavaRDD( jsc: JavaSparkContext, size: Long, @@ -69,6 +72,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default seed. 
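For the RandomRDDs factory methods tagged here, a minimal sketch (size, partition count, and seed are arbitrary; assumes an active SparkContext `sc`):

    import org.apache.spark.mllib.random.RandomRDDs

    val u = RandomRDDs.uniformRDD(sc, 1000L, numPartitions = 4, seed = 11L)  // ~ U(0.0, 1.0)
    println(u.mean())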
*/ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) } @@ -76,6 +80,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) } @@ -92,6 +97,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ N(0.0, 1.0). */ + @Since("1.1.0") def normalRDD( sc: SparkContext, size: Long, @@ -104,6 +110,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalRDD]]. */ + @Since("1.1.0") def normalJavaRDD( jsc: JavaSparkContext, size: Long, @@ -115,6 +122,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) } @@ -122,6 +130,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) } @@ -137,6 +146,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonRDD( sc: SparkContext, mean: Double, @@ -150,6 +160,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonRDD]]. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -162,6 +173,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -173,6 +185,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) } @@ -188,6 +201,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def exponentialRDD( sc: SparkContext, mean: Double, @@ -201,6 +215,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialRDD]]. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -213,6 +228,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -224,6 +240,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def exponentialJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(exponentialRDD(jsc.sc, mean, size)) } @@ -240,6 +257,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def gammaRDD( sc: SparkContext, shape: Double, @@ -254,6 +272,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaRDD]]. 
*/ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -267,6 +286,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default seed. */ + @Since("1.3.0") def gammaJavaRDD( jsc: JavaSparkContext, shape: Double, @@ -279,11 +299,12 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaRDD( - jsc: JavaSparkContext, - shape: Double, - scale: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + shape: Double, + scale: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(gammaRDD(jsc.sc, shape, scale, size)) } @@ -299,6 +320,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Double] comprised of `i.i.d.` samples ~ Pois(mean). */ + @Since("1.3.0") def logNormalRDD( sc: SparkContext, mean: Double, @@ -313,6 +335,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalRDD]]. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -326,6 +349,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( jsc: JavaSparkContext, mean: Double, @@ -338,11 +362,12 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def logNormalJavaRDD( - jsc: JavaSparkContext, - mean: Double, - std: Double, - size: Long): JavaDoubleRDD = { + jsc: JavaSparkContext, + mean: Double, + std: Double, + size: Long): JavaDoubleRDD = { JavaDoubleRDD.fromRDD(logNormalRDD(jsc.sc, mean, std, size)) } @@ -359,6 +384,7 @@ object RandomRDDs { * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomRDD[T: ClassTag]( sc: SparkContext, generator: RandomDataGenerator[T], @@ -381,6 +407,7 @@ object RandomRDDs { * @param seed Seed for the RNG that generates the seed for the generator in each partition. * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. */ + @Since("1.1.0") def uniformVectorRDD( sc: SparkContext, numRows: Long, @@ -394,6 +421,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -406,6 +434,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -417,6 +446,7 @@ object RandomRDDs { /** * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def uniformJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -435,6 +465,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ `N(0.0, 1.0)`. */ + @Since("1.1.0") def normalVectorRDD( sc: SparkContext, numRows: Long, @@ -448,6 +479,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -460,6 +492,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. 
*/ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -471,6 +504,7 @@ object RandomRDDs { /** * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def normalJavaVectorRDD( jsc: JavaSparkContext, numRows: Long, @@ -491,6 +525,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples. */ + @Since("1.3.0") def logNormalVectorRDD( sc: SparkContext, mean: Double, @@ -507,6 +542,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#logNormalVectorRDD]]. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -521,6 +557,7 @@ object RandomRDDs { /** * [[RandomRDDs#logNormalJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -535,6 +572,7 @@ object RandomRDDs { * [[RandomRDDs#logNormalJavaVectorRDD]] with the default number of partitions and * the default seed. */ + @Since("1.3.0") def logNormalJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -556,6 +594,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Pois(mean). */ + @Since("1.1.0") def poissonVectorRDD( sc: SparkContext, mean: Double, @@ -570,6 +609,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -583,6 +623,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -595,6 +636,7 @@ object RandomRDDs { /** * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.1.0") def poissonJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -615,6 +657,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def exponentialVectorRDD( sc: SparkContext, mean: Double, @@ -630,6 +673,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#exponentialVectorRDD]]. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -643,6 +687,7 @@ object RandomRDDs { /** * [[RandomRDDs#exponentialJavaVectorRDD]] with the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -656,6 +701,7 @@ object RandomRDDs { * [[RandomRDDs#exponentialJavaVectorRDD]] with the default number of partitions * and the default seed. */ + @Since("1.3.0") def exponentialJavaVectorRDD( jsc: JavaSparkContext, mean: Double, @@ -678,6 +724,7 @@ object RandomRDDs { * @param seed Random seed (default: a random long integer). * @return RDD[Vector] with vectors containing `i.i.d.` samples ~ Exp(mean). */ + @Since("1.3.0") def gammaVectorRDD( sc: SparkContext, shape: Double, @@ -693,6 +740,7 @@ object RandomRDDs { /** * Java-friendly version of [[RandomRDDs#gammaVectorRDD]]. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -707,6 +755,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default seed. 
*/ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -720,6 +769,7 @@ object RandomRDDs { /** * [[RandomRDDs#gammaJavaVectorRDD]] with the default number of partitions and the default seed. */ + @Since("1.3.0") def gammaJavaVectorRDD( jsc: JavaSparkContext, shape: Double, @@ -744,6 +794,7 @@ object RandomRDDs { * @return RDD[Vector] with vectors containing `i.i.d.` samples produced by generator. */ @DeveloperApi + @Since("1.1.0") def randomVectorRDD(sc: SparkContext, generator: RandomDataGenerator[Double], numRows: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index bd4d81390bfae..92a5af708d04b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -35,9 +35,9 @@ import org.apache.spark.mllib.util.MLUtils */ @Since("1.3.0") @DeveloperApi -class MultivariateGaussian ( - val mu: Vector, - val sigma: Matrix) extends Serializable { +class MultivariateGaussian @Since("1.3.0") ( + @Since("1.3.0") val mu: Vector, + @Since("1.3.0") val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index f44be13706695..d01b3707be944 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} /** * :: Experimental :: @@ -25,28 +25,33 @@ import org.apache.spark.annotation.Experimental * @tparam DF Return type of `degreesOfFreedom`. */ @Experimental +@Since("1.1.0") trait TestResult[DF] { /** * The probability of obtaining a test statistic result at least as extreme as the one that was * actually observed, assuming that the null hypothesis is true. */ + @Since("1.1.0") def pValue: Double /** * Returns the degree(s) of freedom of the hypothesis test. * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. */ + @Since("1.1.0") def degreesOfFreedom: DF /** * Test statistic. */ + @Since("1.1.0") def statistic: Double /** * Null hypothesis of the test. */ + @Since("1.1.0") def nullHypothesis: String /** @@ -78,11 +83,12 @@ trait TestResult[DF] { * Object containing the test results for the chi-squared hypothesis test. 
*/ @Experimental +@Since("1.1.0") class ChiSqTestResult private[stat] (override val pValue: Double, - override val degreesOfFreedom: Int, - override val statistic: Double, - val method: String, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.1.0") override val degreesOfFreedom: Int, + @Since("1.1.0") override val statistic: Double, + @Since("1.1.0") val method: String, + @Since("1.1.0") override val nullHypothesis: String) extends TestResult[Int] { override def toString: String = { "Chi squared test summary:\n" + @@ -96,11 +102,13 @@ class ChiSqTestResult private[stat] (override val pValue: Double, * Object containing the test results for the Kolmogorov-Smirnov test. */ @Experimental +@Since("1.5.0") class KolmogorovSmirnovTestResult private[stat] ( - override val pValue: Double, - override val statistic: Double, - override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val pValue: Double, + @Since("1.5.0") override val statistic: Double, + @Since("1.5.0") override val nullHypothesis: String) extends TestResult[Int] { + @Since("1.5.0") override val degreesOfFreedom = 0 override def toString: String = { From d703372f86d6a59383ba8569fcd9d379849cffbf Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:33:48 -0700 Subject: [PATCH 087/802] [SPARK-10234] [MLLIB] update since version in mllib.clustering Same as #8421 but for `mllib.clustering`. cc feynmanliang yu-iskw Author: Xiangrui Meng Closes #8435 from mengxr/SPARK-10234. --- .../mllib/clustering/GaussianMixture.scala | 1 + .../clustering/GaussianMixtureModel.scala | 8 +++--- .../spark/mllib/clustering/KMeans.scala | 1 + .../spark/mllib/clustering/KMeansModel.scala | 4 +-- .../spark/mllib/clustering/LDAModel.scala | 28 ++++++++++++++----- .../clustering/PowerIterationClustering.scala | 10 +++++-- .../mllib/clustering/StreamingKMeans.scala | 15 +++++----- 7 files changed, 44 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index daa947e81d44d..f82bd82c20371 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -53,6 +53,7 @@ import org.apache.spark.util.Utils * @param maxIterations The maximum number of iterations to perform */ @Experimental +@Since("1.3.0") class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 1a10a8b624218..7f6163e04bf17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -46,9 +46,9 @@ import org.apache.spark.sql.{SQLContext, Row} */ @Since("1.3.0") @Experimental -class GaussianMixtureModel( - val weights: Array[Double], - val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { +class GaussianMixtureModel @Since("1.3.0") ( + @Since("1.3.0") val weights: Array[Double], + @Since("1.3.0") val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable { require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") @@ -178,7 +178,7 @@ object 
GaussianMixtureModel extends Loader[GaussianMixtureModel] { (weight, new MultivariateGaussian(mu, sigma)) }.unzip - return new GaussianMixtureModel(weights.toArray, gaussians.toArray) + new GaussianMixtureModel(weights.toArray, gaussians.toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 3e9545a74bef3..46920fffe6e1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.random.XORShiftRandom * This is an iterative algorithm that will make multiple passes over the data, so any RDDs given * to it should be cached by the user. */ +@Since("0.8.0") class KMeans private ( private var k: Int, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index e425ecdd481c6..a741584982725 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -37,8 +37,8 @@ import org.apache.spark.sql.Row * A clustering model for K-means. Each point belongs to the cluster with the closest center. */ @Since("0.8.0") -class KMeansModel ( - val clusterCenters: Array[Vector]) extends Saveable with Serializable with PMMLExportable { +class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector]) + extends Saveable with Serializable with PMMLExportable { /** * A Java-friendly constructor that takes an Iterable of Vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 432bbedc8d6f8..15129e0dd5a91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -43,12 +43,15 @@ import org.apache.spark.util.BoundedPriorityQueue * including local and distributed data structures. */ @Experimental +@Since("1.3.0") abstract class LDAModel private[clustering] extends Saveable { /** Number of topics */ + @Since("1.3.0") def k: Int /** Vocabulary size (number of terms or terms in the vocabulary) */ + @Since("1.3.0") def vocabSize: Int /** @@ -57,6 +60,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a Dirichlet distribution. */ + @Since("1.5.0") def docConcentration: Vector /** @@ -68,6 +72,7 @@ abstract class LDAModel private[clustering] extends Saveable { * Note: The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ + @Since("1.5.0") def topicConcentration: Double /** @@ -81,6 +86,7 @@ abstract class LDAModel private[clustering] extends Saveable { * This is a matrix of size vocabSize x k, where each column is a topic. * No guarantees are given about the ordering of the topics. */ + @Since("1.3.0") def topicsMatrix: Matrix /** @@ -91,6 +97,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. 
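For the clustering entry points tagged in this patch, a minimal KMeans sketch (toy points; assumes an active SparkContext `sc`):

    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.linalg.Vectors

    val points = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
      Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))
    val model = KMeans.train(points, k = 2, maxIterations = 10)
    model.clusterCenters.foreach(println)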
*/ + @Since("1.3.0") def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] /** @@ -102,6 +109,7 @@ abstract class LDAModel private[clustering] extends Saveable { * (term indices, term weights in topic). * Each topic's terms are sorted in order of decreasing weight. */ + @Since("1.3.0") def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) /* TODO (once LDA can be trained with Strings or given a dictionary) @@ -185,10 +193,11 @@ abstract class LDAModel private[clustering] extends Saveable { * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental +@Since("1.3.0") class LocalLDAModel private[clustering] ( - val topics: Matrix, - override val docConcentration: Vector, - override val topicConcentration: Double, + @Since("1.3.0") val topics: Matrix, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, override protected[clustering] val gammaShape: Double = 100) extends LDAModel with Serializable { @@ -376,6 +385,7 @@ class LocalLDAModel private[clustering] ( } @Experimental +@Since("1.5.0") object LocalLDAModel extends Loader[LocalLDAModel] { private object SaveLoadV1_0 { @@ -479,13 +489,14 @@ object LocalLDAModel extends Loader[LocalLDAModel] { * than the [[LocalLDAModel]]. */ @Experimental +@Since("1.3.0") class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], private[clustering] val globalTopicTotals: LDA.TopicCounts, - val k: Int, - val vocabSize: Int, - override val docConcentration: Vector, - override val topicConcentration: Double, + @Since("1.3.0") val k: Int, + @Since("1.3.0") val vocabSize: Int, + @Since("1.5.0") override val docConcentration: Vector, + @Since("1.5.0") override val topicConcentration: Double, private[spark] val iterationTimes: Array[Double], override protected[clustering] val gammaShape: Double = 100) extends LDAModel { @@ -603,6 +614,7 @@ class DistributedLDAModel private[clustering] ( * (term indices, topic indices). Note that terms will be omitted if not present in * the document. */ + @Since("1.5.0") lazy val topicAssignments: RDD[(Long, Array[Int], Array[Int])] = { // For reference, compare the below code with the core part of EMLDAOptimizer.next(). 
val eta = topicConcentration @@ -634,6 +646,7 @@ class DistributedLDAModel private[clustering] ( } /** Java-friendly version of [[topicAssignments]] */ + @Since("1.5.0") lazy val javaTopicAssignments: JavaRDD[(java.lang.Long, Array[Int], Array[Int])] = { topicAssignments.asInstanceOf[RDD[(java.lang.Long, Array[Int], Array[Int])]].toJavaRDD() } @@ -770,6 +783,7 @@ class DistributedLDAModel private[clustering] ( @Experimental +@Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { private object SaveLoadV1_0 { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 396b36f2f6454..da234bdbb29e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -42,9 +42,10 @@ import org.apache.spark.{Logging, SparkContext, SparkException} */ @Since("1.3.0") @Experimental -class PowerIterationClusteringModel( - val k: Int, - val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable { +class PowerIterationClusteringModel @Since("1.3.0") ( + @Since("1.3.0") val k: Int, + @Since("1.3.0") val assignments: RDD[PowerIterationClustering.Assignment]) + extends Saveable with Serializable { @Since("1.4.0") override def save(sc: SparkContext, path: String): Unit = { @@ -56,6 +57,8 @@ class PowerIterationClusteringModel( @Since("1.4.0") object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] { + + @Since("1.4.0") override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path) } @@ -120,6 +123,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode * @see [[http://en.wikipedia.org/wiki/Spectral_clustering Spectral clustering (Wikipedia)]] */ @Experimental +@Since("1.3.0") class PowerIterationClustering private[clustering] ( private var k: Int, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 41f2668ec6a7d..1d50ffec96faf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -66,9 +66,10 @@ import org.apache.spark.util.random.XORShiftRandom */ @Since("1.2.0") @Experimental -class StreamingKMeansModel( - override val clusterCenters: Array[Vector], - val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { +class StreamingKMeansModel @Since("1.2.0") ( + @Since("1.2.0") override val clusterCenters: Array[Vector], + @Since("1.2.0") val clusterWeights: Array[Double]) + extends KMeansModel(clusterCenters) with Logging { /** * Perform a k-means update on a batch of data. 
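Similarly, a sketch of the PowerIterationClustering API annotated above; the (srcId, dstId, similarity) triples are illustrative:

    import org.apache.spark.mllib.clustering.PowerIterationClustering

    val similarities = sc.parallelize(Seq(
      (0L, 1L, 0.9), (1L, 2L, 0.9), (2L, 3L, 0.1), (3L, 4L, 0.9)))
    val model = new PowerIterationClustering()
      .setK(2)
      .setMaxIterations(10)
      .run(similarities)
    model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))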
@@ -168,10 +169,10 @@ class StreamingKMeansModel( */ @Since("1.2.0") @Experimental -class StreamingKMeans( - var k: Int, - var decayFactor: Double, - var timeUnit: String) extends Logging with Serializable { +class StreamingKMeans @Since("1.2.0") ( + @Since("1.2.0") var k: Int, + @Since("1.2.0") var decayFactor: Double, + @Since("1.2.0") var timeUnit: String) extends Logging with Serializable { @Since("1.2.0") def this() = this(2, 1.0, StreamingKMeans.BATCHES) From fb7e12fe2e14af8de4c206ca8096b2e8113bfddc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:35:49 -0700 Subject: [PATCH 088/802] [SPARK-10243] [MLLIB] update since versions in mllib.tree Same as #8421 but for `mllib.tree`. cc jkbradley Author: Xiangrui Meng Closes #8442 from mengxr/SPARK-10236. --- .../spark/mllib/tree/DecisionTree.scala | 3 +- .../mllib/tree/GradientBoostedTrees.scala | 2 +- .../spark/mllib/tree/configuration/Algo.scala | 2 ++ .../tree/configuration/BoostingStrategy.scala | 12 ++++---- .../tree/configuration/FeatureType.scala | 2 ++ .../tree/configuration/QuantileStrategy.scala | 2 ++ .../mllib/tree/configuration/Strategy.scala | 29 ++++++++++--------- .../mllib/tree/model/DecisionTreeModel.scala | 5 +++- .../apache/spark/mllib/tree/model/Node.scala | 18 ++++++------ .../spark/mllib/tree/model/Predict.scala | 6 ++-- .../apache/spark/mllib/tree/model/Split.scala | 8 ++--- .../mllib/tree/model/treeEnsembleModels.scala | 12 ++++---- 12 files changed, 57 insertions(+), 44 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 972841015d4f0..4a77d4adcd865 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -46,7 +46,8 @@ import org.apache.spark.util.random.XORShiftRandom */ @Since("1.0.0") @Experimental -class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { +class DecisionTree @Since("1.0.0") (private val strategy: Strategy) + extends Serializable with Logging { strategy.assertValid() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index e750408600c33..95ed48cea6716 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.2.0") @Experimental -class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) +class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 8301ad160836b..853c7319ec44d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -26,7 +26,9 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object Algo extends Enumeration { + @Since("1.0.0") type Algo = Value + @Since("1.0.0") val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 7c569981977b4..b5c72fba3ede1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -41,14 +41,14 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} */ @Since("1.2.0") @Experimental -case class BoostingStrategy( +case class BoostingStrategy @Since("1.4.0") ( // Required boosting parameters - @BeanProperty var treeStrategy: Strategy, - @BeanProperty var loss: Loss, + @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, + @Since("1.2.0") @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var validationTol: Double = 1e-5) extends Serializable { + @Since("1.2.0") @BeanProperty var numIterations: Int = 100, + @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, + @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index bb7c7ee4f964f..4e0cd473def06 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object FeatureType extends Enumeration { + @Since("1.0.0") type FeatureType = Value + @Since("1.0.0") val Continuous, Categorical = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index 904e42deebb5f..8262db8a4f111 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since} @Since("1.0.0") @Experimental object QuantileStrategy extends Enumeration { + @Since("1.0.0") type QuantileStrategy = Value + @Since("1.0.0") val Sort, MinMax, ApproxHist = Value } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b74e3f1f46523..89cc13b7c06cf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -69,20 +69,20 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ */ @Since("1.0.0") @Experimental -class Strategy ( - @BeanProperty var algo: Algo, - @BeanProperty var impurity: Impurity, - @BeanProperty var maxDepth: Int, - @BeanProperty var numClasses: Int = 2, - @BeanProperty var maxBins: Int = 32, - @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, - @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - @BeanProperty var minInstancesPerNode: Int = 1, - @BeanProperty var minInfoGain: Double = 0.0, - @BeanProperty var 
maxMemoryInMB: Int = 256, - @BeanProperty var subsamplingRate: Double = 1, - @BeanProperty var useNodeIdCache: Boolean = false, - @BeanProperty var checkpointInterval: Int = 10) extends Serializable { +class Strategy @Since("1.3.0") ( + @Since("1.0.0") @BeanProperty var algo: Algo, + @Since("1.0.0") @BeanProperty var impurity: Impurity, + @Since("1.0.0") @BeanProperty var maxDepth: Int, + @Since("1.2.0") @BeanProperty var numClasses: Int = 2, + @Since("1.0.0") @BeanProperty var maxBins: Int = 32, + @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, + @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, + @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, + @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, + @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { /** */ @@ -206,6 +206,7 @@ object Strategy { } @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") + @Since("1.2.0") def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3eefd135f7836..e1bf23f4c34bb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -43,7 +43,9 @@ import org.apache.spark.util.Utils */ @Since("1.0.0") @Experimental -class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable { +class DecisionTreeModel @Since("1.0.0") ( + @Since("1.0.0") val topNode: Node, + @Since("1.0.0") val algo: Algo) extends Serializable with Saveable { /** * Predict values for a single data point using the model trained. @@ -110,6 +112,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable /** * Print the full model to a string. 
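For the tree APIs annotated in this patch, a minimal classifier sketch (two toy points; assumes an active SparkContext `sc`):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree

    val labeledData = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0))))
    val model = DecisionTree.trainClassifier(labeledData, numClasses = 2,
      categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini", maxDepth = 5, maxBins = 32)
    println(model.toDebugString)  // prints the full tree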
*/ + @Since("1.2.0") def toDebugString: String = { val header = toString + "\n" header + topNode.subtreeToString(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 8c54c55107233..ea6e5aa5d94e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -41,15 +41,15 @@ import org.apache.spark.mllib.linalg.Vector */ @Since("1.0.0") @DeveloperApi -class Node ( - val id: Int, - var predict: Predict, - var impurity: Double, - var isLeaf: Boolean, - var split: Option[Split], - var leftNode: Option[Node], - var rightNode: Option[Node], - var stats: Option[InformationGainStats]) extends Serializable with Logging { +class Node @Since("1.2.0") ( + @Since("1.0.0") val id: Int, + @Since("1.0.0") var predict: Predict, + @Since("1.2.0") var impurity: Double, + @Since("1.0.0") var isLeaf: Boolean, + @Since("1.0.0") var split: Option[Split], + @Since("1.0.0") var leftNode: Option[Node], + @Since("1.0.0") var rightNode: Option[Node], + @Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 965784051ede5..06ceff19d8633 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -26,9 +26,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since} */ @Since("1.2.0") @DeveloperApi -class Predict( - val predict: Double, - val prob: Double = 0.0) extends Serializable { +class Predict @Since("1.2.0") ( + @Since("1.2.0") val predict: Double, + @Since("1.2.0") val prob: Double = 0.0) extends Serializable { override def toString: String = s"$predict (prob = $prob)" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 45db83ae3a1f3..b85a66c05a81d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -34,10 +34,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType @Since("1.0.0") @DeveloperApi case class Split( - feature: Int, - threshold: Double, - featureType: FeatureType, - categories: List[Double]) { + @Since("1.0.0") feature: Int, + @Since("1.0.0") threshold: Double, + @Since("1.0.0") featureType: FeatureType, + @Since("1.0.0") categories: List[Double]) { override def toString: String = { s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 19571447a2c56..df5b8feab5d5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -48,7 +48,9 @@ import org.apache.spark.util.Utils */ @Since("1.2.0") @Experimental -class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) +class RandomForestModel @Since("1.2.0") ( + @Since("1.2.0") 
override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { @@ -115,10 +117,10 @@ object RandomForestModel extends Loader[RandomForestModel] { */ @Since("1.2.0") @Experimental -class GradientBoostedTreesModel( - override val algo: Algo, - override val trees: Array[DecisionTreeModel], - override val treeWeights: Array[Double]) +class GradientBoostedTreesModel @Since("1.2.0") ( + @Since("1.2.0") override val algo: Algo, + @Since("1.2.0") override val trees: Array[DecisionTreeModel], + @Since("1.2.0") override val treeWeights: Array[Double]) extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { From 4657fa1f37d41dd4c7240a960342b68c7c591f48 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 22:49:33 -0700 Subject: [PATCH 089/802] [SPARK-10235] [MLLIB] update since versions in mllib.regression Same as #8421 but for `mllib.regression`. cc freeman-lab dbtsai Author: Xiangrui Meng Closes #8426 from mengxr/SPARK-10235 and squashes the following commits: 6cd28e4 [Xiangrui Meng] update since versions in mllib.regression --- .../regression/GeneralizedLinearAlgorithm.scala | 6 ++++-- .../mllib/regression/IsotonicRegression.scala | 16 +++++++++------- .../spark/mllib/regression/LabeledPoint.scala | 5 +++-- .../apache/spark/mllib/regression/Lasso.scala | 9 ++++++--- .../mllib/regression/LinearRegression.scala | 9 ++++++--- .../spark/mllib/regression/RidgeRegression.scala | 12 +++++++----- .../regression/StreamingLinearAlgorithm.scala | 8 +++----- .../StreamingLinearRegressionWithSGD.scala | 11 +++++++++-- 8 files changed, 47 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 509f6a2d169c4..7e3b4d5648fe3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -38,7 +38,9 @@ import org.apache.spark.storage.StorageLevel */ @Since("0.8.0") @DeveloperApi -abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) +abstract class GeneralizedLinearModel @Since("1.0.0") ( + @Since("1.0.0") val weights: Vector, + @Since("0.8.0") val intercept: Double) extends Serializable { /** @@ -107,7 +109,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * The optimizer to solve the problem. * */ - @Since("1.0.0") + @Since("0.8.0") def optimizer: Optimizer /** Whether to add intercept (default: false). 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index 31ca7c2f207d9..877d31ba41303 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -50,10 +50,10 @@ import org.apache.spark.sql.SQLContext */ @Since("1.3.0") @Experimental -class IsotonicRegressionModel ( - val boundaries: Array[Double], - val predictions: Array[Double], - val isotonic: Boolean) extends Serializable with Saveable { +class IsotonicRegressionModel @Since("1.3.0") ( + @Since("1.3.0") val boundaries: Array[Double], + @Since("1.3.0") val predictions: Array[Double], + @Since("1.3.0") val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -63,7 +63,6 @@ class IsotonicRegressionModel ( /** * A Java-friendly constructor that takes two Iterable parameters and one Boolean parameter. - * */ @Since("1.4.0") def this(boundaries: java.lang.Iterable[Double], @@ -214,8 +213,6 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { } } - /** - */ @Since("1.4.0") override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats @@ -256,6 +253,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { * @see [[http://en.wikipedia.org/wiki/Isotonic_regression Isotonic regression (Wikipedia)]] */ @Experimental +@Since("1.3.0") class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { /** @@ -263,6 +261,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * * @return New instance of IsotonicRegression. */ + @Since("1.3.0") def this() = this(true) /** @@ -271,6 +270,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. * @return This instance of IsotonicRegression. */ + @Since("1.3.0") def setIsotonic(isotonic: Boolean): this.type = { this.isotonic = isotonic this @@ -286,6 +286,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { val preprocessedInput = if (isotonic) { input @@ -311,6 +312,7 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali * the algorithm is executed. * @return Isotonic regression model. */ + @Since("1.3.0") def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index f7fe1b7b21fca..c284ad2325374 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -29,11 +29,12 @@ import org.apache.spark.SparkException * * @param label Label for this data point. * @param features List of features for this data point. 
- * */ @Since("0.8.0") @BeanInfo -case class LabeledPoint(label: Double, features: Vector) { +case class LabeledPoint @Since("1.0.0") ( + @Since("0.8.0") label: Double, + @Since("1.0.0") features: Vector) { override def toString: String = { s"($label,$features)" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index 556411a366bd2..a9aba173fa0e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -34,9 +34,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class LassoModel ( - override val weights: Vector, - override val intercept: Double) +class LassoModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -84,6 +84,7 @@ object LassoModel extends Loader[LassoModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LassoWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -93,6 +94,7 @@ class LassoWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new L1Updater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -103,6 +105,7 @@ class LassoWithSGD private ( * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 00ab06e3ba738..4996ace5df85d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -34,9 +34,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class LinearRegressionModel ( - override val weights: Vector, - override val intercept: Double) +class LinearRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -85,6 +85,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -93,6 +94,7 @@ class LinearRegressionWithSGD private[mllib] ( private val gradient = new LeastSquaresGradient() private val updater = new SimpleUpdater() + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -102,6 +104,7 @@ class LinearRegressionWithSGD private[mllib] ( * Construct a LinearRegression object with default parameters: {stepSize: 1.0, * numIterations: 100, miniBatchFraction: 1.0}. 
*/ + @Since("0.8.0") def this() = this(1.0, 100, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 21a791d98b2cb..0a44ff559d55b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -35,9 +35,9 @@ import org.apache.spark.rdd.RDD * */ @Since("0.8.0") -class RidgeRegressionModel ( - override val weights: Vector, - override val intercept: Double) +class RidgeRegressionModel @Since("1.1.0") ( + @Since("1.0.0") override val weights: Vector, + @Since("0.8.0") override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable with Saveable with PMMLExportable { @@ -85,6 +85,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { * its corresponding right hand side label y. * See also the documentation for the precise formulation. */ +@Since("0.8.0") class RidgeRegressionWithSGD private ( private var stepSize: Double, private var numIterations: Int, @@ -94,7 +95,7 @@ class RidgeRegressionWithSGD private ( private val gradient = new LeastSquaresGradient() private val updater = new SquaredL2Updater() - + @Since("0.8.0") override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) @@ -105,6 +106,7 @@ class RidgeRegressionWithSGD private ( * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, * regParam: 0.01, miniBatchFraction: 1.0}. */ + @Since("0.8.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { @@ -134,7 +136,7 @@ object RidgeRegressionWithSGD { * the number of features in the data. * */ - @Since("0.8.0") + @Since("1.0.0") def train( input: RDD[LabeledPoint], numIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index cd3ed8a1549db..73948b2d9851a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream @@ -83,9 +83,8 @@ abstract class StreamingLinearAlgorithm[ * batch of data from the stream. * * @param data DStream containing labeled data - * */ - @Since("1.3.0") + @Since("1.1.0") def trainOn(data: DStream[LabeledPoint]): Unit = { if (model.isEmpty) { throw new IllegalArgumentException("Model must be initialized before starting training.") @@ -105,7 +104,6 @@ abstract class StreamingLinearAlgorithm[ /** * Java-friendly version of `trainOn`. 
- * */ @Since("1.3.0") def trainOn(data: JavaDStream[LabeledPoint]): Unit = trainOn(data.dstream) @@ -129,7 +127,7 @@ abstract class StreamingLinearAlgorithm[ * Java-friendly version of `predictOn`. * */ - @Since("1.1.0") + @Since("1.3.0") def predictOn(data: JavaDStream[Vector]): JavaDStream[java.lang.Double] = { JavaDStream.fromDStream(predictOn(data.dstream).asInstanceOf[DStream[java.lang.Double]]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 26654e4a06838..fe1d487cdd078 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.mllib.linalg.Vector /** @@ -41,6 +41,7 @@ import org.apache.spark.mllib.linalg.Vector * .trainOn(DStream) */ @Experimental +@Since("1.1.0") class StreamingLinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -54,8 +55,10 @@ class StreamingLinearRegressionWithSGD private[mllib] ( * Initial weights must be set before using trainOn or predictOn * (see `StreamingLinearAlgorithm`) */ + @Since("1.1.0") def this() = this(0.1, 50, 1.0) + @Since("1.1.0") val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) protected var model: Option[LinearRegressionModel] = None @@ -63,6 +66,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the step size for gradient descent. Default: 0.1. */ + @Since("1.1.0") def setStepSize(stepSize: Double): this.type = { this.algorithm.optimizer.setStepSize(stepSize) this @@ -71,6 +75,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the number of iterations of gradient descent to run per update. Default: 50. */ + @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) this @@ -79,6 +84,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the fraction of each batch to use for updates. Default: 1.0. */ + @Since("1.1.0") def setMiniBatchFraction(miniBatchFraction: Double): this.type = { this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction) this @@ -87,6 +93,7 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the initial weights. */ + @Since("1.1.0") def setInitialWeights(initialWeights: Vector): this.type = { this.model = Some(algorithm.createModel(initialWeights, 0.0)) this @@ -95,9 +102,9 @@ class StreamingLinearRegressionWithSGD private[mllib] ( /** * Set the convergence tolerance. Default: 0.001. */ + @Since("1.5.0") def setConvergenceTol(tolerance: Double): this.type = { this.algorithm.optimizer.setConvergenceTol(tolerance) this } - } From 321d7759691bed9867b1f0470f12eab2faa50aff Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 25 Aug 2015 23:45:41 -0700 Subject: [PATCH 090/802] [SPARK-10236] [MLLIB] update since versions in mllib.feature Same as #8421 but for `mllib.feature`. 
cc dbtsai Author: Xiangrui Meng Closes #8449 from mengxr/SPARK-10236.feature and squashes the following commits: 0e8d658 [Xiangrui Meng] remove unnecessary comment ad70b03 [Xiangrui Meng] update since versions in mllib.feature --- .../mllib/clustering/PowerIterationClustering.scala | 2 -- .../apache/spark/mllib/feature/ChiSqSelector.scala | 4 ++-- .../spark/mllib/feature/ElementwiseProduct.scala | 3 ++- .../scala/org/apache/spark/mllib/feature/IDF.scala | 6 ++++-- .../org/apache/spark/mllib/feature/Normalizer.scala | 2 +- .../scala/org/apache/spark/mllib/feature/PCA.scala | 7 +++++-- .../apache/spark/mllib/feature/StandardScaler.scala | 12 ++++++------ .../org/apache/spark/mllib/feature/Word2Vec.scala | 1 + 8 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index da234bdbb29e6..6c76e26fd1626 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -71,8 +71,6 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel" - /** - */ @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { val sqlContext = new SQLContext(sc) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index fdd974d7a391e..4743cfd1a2c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD */ @Since("1.3.0") @Experimental -class ChiSqSelectorModel ( +class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer { require(isSorted(selectedFeatures), "Array has to be sorted asc") @@ -112,7 +112,7 @@ class ChiSqSelectorModel ( */ @Since("1.3.0") @Experimental -class ChiSqSelector ( +class ChiSqSelector @Since("1.3.0") ( @Since("1.3.0") val numTopFeatures: Int) extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index 33e2d17bb472e..d0a6cf61687a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -29,7 +29,8 @@ import org.apache.spark.mllib.linalg._ */ @Since("1.4.0") @Experimental -class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { +class ElementwiseProduct @Since("1.4.0") ( + @Since("1.4.0") val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index d5353ddd972e0..68078ccfa3d60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -39,8 +39,9 @@ import org.apache.spark.rdd.RDD */ @Since("1.1.0") @Experimental -class IDF(val minDocFreq: Int) { +class IDF @Since("1.2.0") (@Since("1.2.0") val minDocFreq: Int) { + @Since("1.1.0") def this() = this(0) // TODO: Allow different IDF formulations. @@ -162,7 +163,8 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[spark] (val idf: Vector) extends Serializable { +@Since("1.1.0") +class IDFModel private[spark] (@Since("1.1.0") val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 0e070257d9fb2..8d5a22520d6b8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors */ @Since("1.1.0") @Experimental -class Normalizer(p: Double) extends VectorTransformer { +class Normalizer @Since("1.1.0") (p: Double) extends VectorTransformer { @Since("1.1.0") def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index a48b7bba665d7..ecb3c1e6c1c83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD * @param k number of principal components */ @Since("1.4.0") -class PCA(val k: Int) { +class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { require(k >= 1, s"PCA requires a number of principal components k >= 1 but was given $k") /** @@ -74,7 +74,10 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +@Since("1.4.0") +class PCAModel private[spark] ( + @Since("1.4.0") val k: Int, + @Since("1.4.0") val pc: DenseMatrix) extends VectorTransformer { /** * Transform a vector by computed Principal Components. 
* diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index b95d5a899001e..f018b453bae7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD */ @Since("1.1.0") @Experimental -class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { +class StandardScaler @Since("1.1.0") (withMean: Boolean, withStd: Boolean) extends Logging { @Since("1.1.0") def this() = this(false, true) @@ -74,11 +74,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { */ @Since("1.1.0") @Experimental -class StandardScalerModel ( - val std: Vector, - val mean: Vector, - var withStd: Boolean, - var withMean: Boolean) extends VectorTransformer { +class StandardScalerModel @Since("1.3.0") ( + @Since("1.3.0") val std: Vector, + @Since("1.1.0") val mean: Vector, + @Since("1.3.0") var withStd: Boolean, + @Since("1.3.0") var withMean: Boolean) extends VectorTransformer { /** */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index e6f45ae4b01d5..36b124c5d2966 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -436,6 +436,7 @@ class Word2Vec extends Serializable with Logging { * (i * vectorSize, i * vectorSize + vectorSize) */ @Experimental +@Since("1.1.0") class Word2VecModel private[mllib] ( private val wordIndex: Map[String, Int], private val wordVectors: Array[Float]) extends Serializable with Saveable { From 75d4773aa50e24972c533e8b48697fde586429eb Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 25 Aug 2015 23:48:16 -0700 Subject: [PATCH 091/802] [SPARK-9316] [SPARKR] Add support for filtering using `[` (synonym for filter / select) Add support for ``` df[df$name == "Smith", c(1,2)] df[df$age %in% c(19, 30), 1:2] ``` shivaram Author: felixcheung Closes #8394 from felixcheung/rsubset. --- R/pkg/R/DataFrame.R | 22 +++++++++++++++++++++- R/pkg/inst/tests/test_sparkSQL.R | 27 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index ae1d912cf6da1..a5162de705f8f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -985,9 +985,11 @@ setMethod("$<-", signature(x = "DataFrame"), x }) +setClassUnion("numericOrcharacter", c("numeric", "character")) + #' @rdname select #' @name [[ -setMethod("[[", signature(x = "DataFrame"), +setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { if (is.numeric(i)) { cols <- columns(x) @@ -1010,6 +1012,20 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) +#' @rdname select +#' @name [ +setMethod("[", signature(x = "DataFrame", i = "Column"), + function(x, i, j, ...) { + # It could handle i as "character" but it seems confusing and not required + # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html + filtered <- filter(x, i) + if (!missing(j)) { + filtered[, j] + } else { + filtered + } + }) + #' Select #' #' Selects a set of columns with names or Column expressions. 
@@ -1028,8 +1044,12 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), #' # Columns can also be selected using `[[` and `[` #' df[[2]] == df[["age"]] #' df[,2] == df[,"age"] +#' df[,c("name", "age")] #' # Similar to R data frames columns can also be selected using `$` #' df$age +#' # It can also be subset on rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 556b8c5447054..ee48a3dc0cc05 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -587,6 +587,33 @@ test_that("select with column", { expect_equal(collect(select(df3, "x"))[[1, 1]], "x") }) +test_that("subsetting", { + # jsonFile returns columns in random order + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") + filtered <- df[df$age > 20,] + expect_equal(count(filtered), 1) + expect_equal(columns(filtered), c("name", "age")) + expect_equal(collect(filtered)$name, "Andy") + + df2 <- df[df$age == 19, 1] + expect_is(df2, "DataFrame") + expect_equal(count(df2), 1) + expect_equal(columns(df2), c("name")) + expect_equal(collect(df2)$name, "Justin") + + df3 <- df[df$age > 20, 2] + expect_equal(count(df3), 1) + expect_equal(columns(df3), c("age")) + + df4 <- df[df$age %in% c(19, 30), 1:2] + expect_equal(count(df4), 2) + expect_equal(columns(df4), c("name", "age")) + + df5 <- df[df$age %in% c(19), c(1,2)] + expect_equal(count(df5), 1) + expect_equal(columns(df5), c("name", "age")) +}) + test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") From bb1640529725c6c38103b95af004f8bd90eeee5c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 26 Aug 2015 00:37:04 -0700 Subject: [PATCH 092/802] Closes #8443 From 6519fd06cc8175c9182ef16cf8a37d7f255eb846 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Aug 2015 11:47:05 -0700 Subject: [PATCH 093/802] [SPARK-9665] [MLLIB] audit MLlib API annotations I only found `ml.NaiveBayes` missing `Experimental` annotation. This PR doesn't cover Python APIs. cc jkbradley Author: Xiangrui Meng Closes #8452 from mengxr/SPARK-9665. 
--- .../apache/spark/ml/classification/NaiveBayes.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 97cbaf1fa8761..69cb88a7e6718 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param, DoubleParam} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} -import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} +import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -59,6 +59,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { } /** + * :: Experimental :: * Naive Bayes Classifiers. * It supports both Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) @@ -68,6 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]). * The input feature values must be nonnegative. */ +@Experimental class NaiveBayes(override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { @@ -101,11 +103,13 @@ class NaiveBayes(override val uid: String) } /** + * :: Experimental :: * Model produced by [[NaiveBayes]] * @param pi log of class priors, whose dimension is C (number of classes) * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ +@Experimental class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, From de7209c256aaf79a0978cfcf6e98bb013267b93a Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 26 Aug 2015 12:19:36 -0700 Subject: [PATCH 094/802] HOTFIX: Increase PRB timeout --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index c4d39d95d5890..f144c053046c5 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -48,8 +48,8 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" # format: http://linux.die.net/man/1/timeout -# must be less than the timeout configured on Jenkins (currently 180m) -TESTS_TIMEOUT="175m" +# must be less than the timeout configured on Jenkins (currently 300m) +TESTS_TIMEOUT="250m" # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. 
From 086d4681df3ebfccfc04188262c10482f44553b0 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 26 Aug 2015 14:02:19 -0700 Subject: [PATCH 095/802] [SPARK-10241] [MLLIB] update since versions in mllib.recommendation Same as #8421 but for `mllib.recommendation`. cc srowen coderxiang Author: Xiangrui Meng Closes #8432 from mengxr/SPARK-10241. --- .../spark/mllib/recommendation/ALS.scala | 22 ++++++++++++++++++- .../MatrixFactorizationModel.scala | 8 +++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index b27ef1b949e2e..33aaf853e599d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -28,7 +28,10 @@ import org.apache.spark.storage.StorageLevel * A more compact class to represent a rating than Tuple3[Int, Int, Double]. */ @Since("0.8.0") -case class Rating(user: Int, product: Int, rating: Double) +case class Rating @Since("0.8.0") ( + @Since("0.8.0") user: Int, + @Since("0.8.0") product: Int, + @Since("0.8.0") rating: Double) /** * Alternating Least Squares matrix factorization. @@ -59,6 +62,7 @@ case class Rating(user: Int, product: Int, rating: Double) * indicated user * preferences rather than explicit ratings given to items. */ +@Since("0.8.0") class ALS private ( private var numUserBlocks: Int, private var numProductBlocks: Int, @@ -74,6 +78,7 @@ class ALS private ( * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. */ + @Since("0.8.0") def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) /** If true, do alternating nonnegative least squares. */ @@ -90,6 +95,7 @@ class ALS private ( * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ + @Since("0.8.0") def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks @@ -99,6 +105,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ + @Since("1.1.0") def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this @@ -107,30 +114,35 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ + @Since("1.1.0") def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ + @Since("0.8.0") def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ + @Since("0.8.0") def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ + @Since("0.8.0") def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ + @Since("0.8.1") def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this @@ -139,12 +151,14 @@ class ALS private ( /** * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. 
*/ + @Since("0.8.1") def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ + @Since("1.0.0") def setSeed(seed: Long): this.type = { this.seed = seed this @@ -154,6 +168,7 @@ class ALS private ( * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. */ + @Since("1.1.0") def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this @@ -166,6 +181,7 @@ class ALS private ( * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. */ @DeveloperApi + @Since("1.1.0") def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { require(storageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") @@ -181,6 +197,7 @@ class ALS private ( * at the cost of speed. */ @DeveloperApi + @Since("1.3.0") def setFinalRDDStorageLevel(storageLevel: StorageLevel): this.type = { this.finalRDDStorageLevel = storageLevel this @@ -194,6 +211,7 @@ class ALS private ( * this setting is ignored. */ @DeveloperApi + @Since("1.4.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { this.checkpointInterval = checkpointInterval this @@ -203,6 +221,7 @@ class ALS private ( * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ + @Since("0.8.0") def run(ratings: RDD[Rating]): MatrixFactorizationModel = { val sc = ratings.context @@ -250,6 +269,7 @@ class ALS private ( /** * Java-friendly version of [[ALS.run]]. */ + @Since("1.3.0") def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index ba4cfdcd9f1dd..46562eb2ad0f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -52,10 +52,10 @@ import org.apache.spark.storage.StorageLevel * and the features computed for this product. */ @Since("0.8.0") -class MatrixFactorizationModel( - val rank: Int, - val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) +class MatrixFactorizationModel @Since("0.8.0") ( + @Since("0.8.0") val rank: Int, + @Since("0.8.0") val userFeatures: RDD[(Int, Array[Double])], + @Since("0.8.0") val productFeatures: RDD[(Int, Array[Double])]) extends Saveable with Serializable with Logging { require(rank > 0) From d41d6c48207159490c1e1d9cc54015725cfa41b2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 26 Aug 2015 16:04:44 -0700 Subject: [PATCH 096/802] [SPARK-10305] [SQL] fix create DataFrame from Python class cc jkbradley Author: Davies Liu Closes #8470 from davies/fix_create_df. 
--- python/pyspark/sql/tests.py | 12 ++++++++++++ python/pyspark/sql/types.py | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aacfb34c77618..cd32e26c64f22 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint): __UDT__ = PythonOnlyUDT() +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -383,6 +389,12 @@ def test_infer_nested_schema(self): df = self.sqlCtx.inferSchema(rdd) self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.sqlCtx.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") self.assertEquals(Row(col=None), df.first()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ed4e5b594bd61..94e581a78364c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -537,6 +537,9 @@ def toInternal(self, obj): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -544,6 +547,9 @@ def toInternal(self, obj): return tuple(obj.get(n) for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(d.get(n) for n in self.names) else: raise ValueError("Unexpected tuple %r with StructType" % obj) From ad7f0f160be096c0fdae6e6cf7e3b6ba4a606de7 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 26 Aug 2015 18:13:07 -0700 Subject: [PATCH 097/802] [SPARK-10308] [SPARKR] Add %in% to the exported namespace I also checked all the other functions defined in column.R, functions.R and DataFrame.R and everything else looked fine. cc yu-iskw Author: Shivaram Venkataraman Closes #8473 from shivaram/in-namespace. --- R/pkg/NAMESPACE | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3e5c89d779b7b..5286c01986204 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -47,12 +47,12 @@ exportMethods("arrange", "join", "limit", "merge", + "mutate", + "na.omit", "names", "ncol", "nrow", "orderBy", - "mutate", - "names", "persist", "printSchema", "rbind", @@ -82,7 +82,8 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("abs", +exportMethods("%in%", + "abs", "acos", "add_months", "alias", From 773ca037a43d464ce7f16fe693ca6034f09a35b7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 26 Aug 2015 18:14:32 -0700 Subject: [PATCH 098/802] [MINOR] [SPARKR] Fix some validation problems in SparkR Getting rid of some validation problems in SparkR https://github.com/apache/spark/pull/7883 cc shivaram ``` inst/tests/test_Serde.R:26:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:34:1: style: Trailing whitespace is superfluous. 
^~ inst/tests/test_Serde.R:37:38: style: Trailing whitespace is superfluous. expect_equal(class(x), "character") ^~ inst/tests/test_Serde.R:50:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:55:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_Serde.R:60:1: style: Trailing whitespace is superfluous. ^~ inst/tests/test_sparkSQL.R:611:1: style: Trailing whitespace is superfluous. ^~ R/DataFrame.R:664:1: style: Trailing whitespace is superfluous. ^~~~~~~~~~~~~~ R/DataFrame.R:670:55: style: Trailing whitespace is superfluous. df <- data.frame(row.names = 1 : nrow) ^~~~~~~~~~~~~~~~ R/DataFrame.R:672:1: style: Trailing whitespace is superfluous. ^~~~~~~~~~~~~~ R/DataFrame.R:686:49: style: Trailing whitespace is superfluous. df[[names[colIndex]]] <- vec ^~~~~~~~~~~~~~~~~~ ``` Author: Yu ISHIKAWA Closes #8474 from yu-iskw/minor-fix-sparkr. --- R/pkg/R/DataFrame.R | 8 ++++---- R/pkg/inst/tests/test_Serde.R | 12 ++++++------ R/pkg/inst/tests/test_sparkSQL.R | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a5162de705f8f..dd8126aebf467 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -661,15 +661,15 @@ setMethod("collect", # listCols is a list of columns listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf) stopifnot(length(listCols) == ncol) - + # An empty data.frame with 0 columns and number of rows as collected nrow <- length(listCols[[1]]) if (nrow <= 0) { df <- data.frame() } else { - df <- data.frame(row.names = 1 : nrow) + df <- data.frame(row.names = 1 : nrow) } - + # Append columns one by one for (colIndex in 1 : ncol) { # Note: appending a column of list type into a data.frame so that @@ -683,7 +683,7 @@ setMethod("collect", # TODO: more robust check on column of primitive types vec <- do.call(c, col) if (class(vec) != "list") { - df[[names[colIndex]]] <- vec + df[[names[colIndex]]] <- vec } else { # For columns of complex type, be careful to access them. # Get a column of complex type returns a list. 
diff --git a/R/pkg/inst/tests/test_Serde.R b/R/pkg/inst/tests/test_Serde.R index 009db85da2beb..dddce54d70443 100644 --- a/R/pkg/inst/tests/test_Serde.R +++ b/R/pkg/inst/tests/test_Serde.R @@ -23,7 +23,7 @@ test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") - + x <- callJStatic("SparkRHandler", "echo", 1) expect_equal(x, 1) expect_equal(class(x), "numeric") @@ -31,10 +31,10 @@ test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", TRUE) expect_true(x) expect_equal(class(x), "logical") - + x <- callJStatic("SparkRHandler", "echo", "abc") expect_equal(x, "abc") - expect_equal(class(x), "character") + expect_equal(class(x), "character") }) test_that("SerDe of list of primitive types", { @@ -47,17 +47,17 @@ test_that("SerDe of list of primitive types", { y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "numeric") - + x <- list(TRUE, FALSE) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "logical") - + x <- list("a", "b", "c") y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) expect_equal(class(y[[1]]), "character") - + # Empty list x <- list() y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index ee48a3dc0cc05..8e22c56824b16 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -608,7 +608,7 @@ test_that("subsetting", { df4 <- df[df$age %in% c(19, 30), 1:2] expect_equal(count(df4), 2) expect_equal(columns(df4), c("name", "age")) - + df5 <- df[df$age %in% c(19), c(1,2)] expect_equal(count(df5), 1) expect_equal(columns(df5), c("name", "age")) From 0fac144f6bd835395059154532d72cdb5dc7ef8d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 26 Aug 2015 18:14:54 -0700 Subject: [PATCH 099/802] [SPARK-9424] [SQL] Parquet programming guide updates for 1.5 Author: Cheng Lian Closes #8467 from liancheng/spark-9424/parquet-docs-for-1.5. --- docs/sql-programming-guide.md | 45 ++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 33e7893d7bd0a..e64190b9b209d 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1124,6 +1124,13 @@ a simple schema, and gradually add more columns to the schema as needed. In thi up with multiple Parquet files with different but mutually compatible schemas. The Parquet data source is now able to automatically detect this case and merge schemas of all these files. +Since schema merging is a relatively expensive operation, and is not a necessity in most cases, we +turned it off by default starting from 1.5.0. You may enable it by + +1. setting data source option `mergeSchema` to `true` when reading Parquet files (as shown in the + examples below), or +2. setting the global SQL option `spark.sql.parquet.mergeSchema` to `true`. +
@@ -1143,7 +1150,7 @@ val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.read.parquet("data/test_table") +val df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1165,16 +1172,16 @@ df3.printSchema() # Create a simple DataFrame, stored into a partition directory df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ .map(lambda i: Row(single=i, double=i * 2))) -df1.save("data/test_table/key=1", "parquet") +df1.write.parquet("data/test_table/key=1") # Create another DataFrame in a new partition directory, # adding a new column and dropping an existing column df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) .map(lambda i: Row(single=i, triple=i * 3))) -df2.save("data/test_table/key=2", "parquet") +df2.write.parquet("data/test_table/key=2") # Read the partitioned table -df3 = sqlContext.load("data/test_table", "parquet") +df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1201,7 +1208,7 @@ saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") # Read the partitioned table -df3 <- loadDF(sqlContext, "data/test_table", "parquet") +df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema="true") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together @@ -1301,7 +1308,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.binaryAsString false - Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do + Some other Parquet-producing systems, in particular Impala, Hive, and older versions of Spark SQL, do not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1310,8 +1317,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.int96AsTimestamp true - Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. Spark would also - store Timestamp as INT96 because we need to avoid precision lost of the nanoseconds field. This + Some Parquet-producing systems, in particular Impala and Hive, store Timestamp into INT96. This flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems. @@ -1355,6 +1361,9 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

Note:

    +
  • + This option is automatically ignored if spark.speculation is turned on. +
  • This option must be set via Hadoop Configuration rather than Spark SQLConf. @@ -1371,6 +1380,26 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

    + + spark.sql.parquet.mergeSchema + false + +

    + When true, the Parquet data source merges schemas collected from all data files, otherwise the + schema is picked from the summary file or a random data file if no summary file is available. +

    + + + + spark.sql.parquet.mergeSchema + false + +

    + When true, the Parquet data source merges schemas collected from all data files, otherwise the + schema is picked from the summary file or a random data file if no summary file is available. +

    + + ## JSON Datasets From ce97834dc0cc55eece0e909a4061ca6f2123f60d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 26 Aug 2015 22:19:11 -0700 Subject: [PATCH 100/802] [SPARK-9964] [PYSPARK] [SQL] PySpark DataFrameReader accept RDD of String for JSON PySpark DataFrameReader should could accept an RDD of Strings (like the Scala version does) for JSON, rather than only taking a path. If this PR is merged, it should be duplicated to cover the other input types (not just JSON). Author: Yanbo Liang Closes #8444 from yanboliang/spark-9964. --- python/pyspark/sql/readwriter.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 78247c8fa7372..3fa6895880a97 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -15,8 +15,14 @@ # limitations under the License. # +import sys + +if sys.version >= '3': + basestring = unicode = str + from py4j.java_gateway import JavaClass +from pyspark import RDD from pyspark.sql import since from pyspark.sql.column import _to_seq from pyspark.sql.types import * @@ -125,23 +131,33 @@ def load(self, path=None, format=None, schema=None, **options): @since(1.4) def json(self, path, schema=None): """ - Loads a JSON file (one object per line) and returns the result as - a :class`DataFrame`. + Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects + (one object per record) and returns the result as a :class`DataFrame`. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. - :param path: string, path to the JSON dataset. + :param path: string represents path to the JSON dataset, + or RDD of Strings storing JSON objects. :param schema: an optional :class:`StructType` for the input schema. - >>> df = sqlContext.read.json('python/test_support/sql/people.json') - >>> df.dtypes + >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') + >>> df1.dtypes + [('age', 'bigint'), ('name', 'string')] + >>> rdd = sc.textFile('python/test_support/sql/people.json') + >>> df2 = sqlContext.read.json(rdd) + >>> df2.dtypes [('age', 'bigint'), ('name', 'string')] """ if schema is not None: self.schema(schema) - return self._df(self._jreader.json(path)) + if isinstance(path, basestring): + return self._df(self._jreader.json(path)) + elif isinstance(path, RDD): + return self._df(self._jreader.json(path._jrdd)) + else: + raise TypeError("path can be only string or RDD") @since(1.4) def table(self, tableName): From e936cf8088a06d6aefce44305f3904bbeb17b432 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 26 Aug 2015 22:27:31 -0700 Subject: [PATCH 101/802] [SPARK-10219] [SPARKR] Fix varargsToEnv and add test case cc sun-rui davies Author: Shivaram Venkataraman Closes #8475 from shivaram/varargs-fix. --- R/pkg/R/utils.R | 3 ++- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 4f9f4d9cad2a8..3babcb519378e 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -314,7 +314,8 @@ convertEnvsToList <- function(keys, vals) { # Utility function to capture the varargs into environment object varargsToEnv <- function(...) { - pairs <- as.list(substitute(list(...)))[-1L] + # Based on http://stackoverflow.com/a/3057419/4577954 + pairs <- list(...) 
env <- new.env() for (name in names(pairs)) { env[[name]] <- pairs[[name]] diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 8e22c56824b16..4b672e115f924 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1060,6 +1060,12 @@ test_that("parquetFile works with multiple input paths", { parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") expect_equal(count(parquetDF), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) }) test_that("describe() and summarize() on a DataFrame", { From de0278286cf6db8df53b0b68918ea114f2c77f1f Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 26 Aug 2015 23:12:55 -0700 Subject: [PATCH 102/802] =?UTF-8?q?[SPARK-10251]=20[CORE]=20some=20common?= =?UTF-8?q?=20types=20are=20not=20registered=20for=20Kryo=20Serializat?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ion by default Author: Ram Sriharsha Closes #8465 from harsha2010/SPARK-10251. --- .../spark/serializer/KryoSerializer.scala | 35 ++++++++++++++++++- .../serializer/KryoSerializerSuite.scala | 30 ++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 048a938507277..b977711e7d5ad 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import javax.annotation.Nullable import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import com.esotericsoftware.kryo.{Kryo, KryoException} @@ -38,7 +39,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} +import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -131,6 +132,38 @@ class KryoSerializer(conf: SparkConf) // our code override the generic serializers in Chill for things like Seq new AllScalaRegistrar().apply(kryo) + // Register types missed by Chill. 
+ // scalastyle:off + kryo.register(classOf[Array[Tuple1[Any]]]) + kryo.register(classOf[Array[Tuple2[Any, Any]]]) + kryo.register(classOf[Array[Tuple3[Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple4[Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple5[Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple6[Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple7[Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple8[Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple9[Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + kryo.register(classOf[Array[Tuple22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]]) + + // scalastyle:on + + kryo.register(None.getClass) + kryo.register(Nil.getClass) + kryo.register(Utils.classForName("scala.collection.immutable.$colon$colon")) + kryo.register(classOf[ArrayBuffer[Any]]) + kryo.setClassLoader(classLoader) kryo } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 8d1c9d17e977e..e428414cf6e85 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -150,6 +150,36 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { mutable.HashMap(1->"one", 2->"two", 3->"three"))) } + test("Bug: SPARK-10251") { + val ser = new KryoSerializer(conf.clone.set("spark.kryo.registrationRequired", "true")) + .newInstance() + def check[T: ClassTag](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + } + check((1, 3)) + check(Array((1, 3))) + check(List((1, 3))) + check(List[Int]()) + check(List[Int](1, 2, 3)) + check(List[String]()) + check(List[String]("x", "y", "z")) + check(None) + check(Some(1)) + check(Some("hi")) + check(1 -> 1) + check(mutable.ArrayBuffer(1, 2, 3)) + check(mutable.ArrayBuffer("1", "2", "3")) + check(mutable.Map()) + check(mutable.Map(1 -> "one", 2 -> "two")) + 
check(mutable.Map("one" -> 1, "two" -> 2)) + check(mutable.HashMap(1 -> "one", 2 -> "two")) + check(mutable.HashMap("one" -> 1, "two" -> 2)) + check(List(Some(mutable.HashMap(1->1, 2->2)), None, Some(mutable.HashMap(3->4)))) + check(List( + mutable.HashMap("one" -> 1, "two" -> 2), + mutable.HashMap(1->"one", 2->"two", 3->"three"))) + } + test("ranges") { val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { From 9625d13d575c97bbff264f6a94838aae72c9202d Mon Sep 17 00:00:00 2001 From: Moussa Taifi Date: Thu, 27 Aug 2015 10:34:47 +0100 Subject: [PATCH 103/802] [DOCS] [STREAMING] [KAFKA] Fix typo in exactly once semantics Fix Typo in exactly once semantics [Semantics of output operations] link Author: Moussa Taifi Closes #8468 from moutai/patch-3. --- docs/streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 7571e22575efd..5db39ae54a274 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -82,7 +82,7 @@ This approach has the following advantages over the receiver-based approach (i.e - *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. 
This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). From 1650f6f56ed4b7f1a7f645c9e8d5ac533464bd78 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:44:44 +0100 Subject: [PATCH 104/802] [SPARK-10254] [ML] Removes Guava dependencies in spark.ml.feature JavaTests * Replaces `com.google.common` dependencies with `java.util.Arrays` * Small clean up in `JavaNormalizerSuite` Author: Feynman Liang Closes #8445 from feynmanliang/SPARK-10254. --- .../apache/spark/ml/feature/JavaBucketizerSuite.java | 5 +++-- .../org/apache/spark/ml/feature/JavaDCTSuite.java | 5 +++-- .../apache/spark/ml/feature/JavaHashingTFSuite.java | 5 +++-- .../apache/spark/ml/feature/JavaNormalizerSuite.java | 11 +++++------ .../org/apache/spark/ml/feature/JavaPCASuite.java | 4 ++-- .../ml/feature/JavaPolynomialExpansionSuite.java | 5 +++-- .../spark/ml/feature/JavaStandardScalerSuite.java | 4 ++-- .../apache/spark/ml/feature/JavaTokenizerSuite.java | 6 ++++-- .../spark/ml/feature/JavaVectorIndexerSuite.java | 5 ++--- .../spark/ml/feature/JavaVectorSlicerSuite.java | 4 ++-- .../apache/spark/ml/feature/JavaWord2VecSuite.java | 11 ++++++----- 11 files changed, 35 insertions(+), 30 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index d5bd230a957a1..47d68de599da2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,7 @@ public void tearDown() { public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - JavaRDD data = jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), RowFactory.create(0.0), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 845eed61c45c6..0f6ec64d97d36 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; import org.junit.Assert; @@ -56,7 +57,7 @@ public void tearDown() { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - JavaRDD data = 
jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.dense(input)) )); DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 599e9cfd23ad4..03dd5369bddf7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,7 @@ public void tearDown() { @Test public void hashingTF() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0.0, "Hi I heard about Spark"), RowFactory.create(0.0, "I wish Java could use case classes"), RowFactory.create(1.0, "Logistic regression models are neat") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index d82f3b7e8c076..e17d549c5059b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature; -import java.util.List; +import java.util.Arrays; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; @@ -48,13 +48,12 @@ public void tearDown() { @Test public void normalizer() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + JavaRDD points = jsc.parallelize(Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) - ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), - VectorIndexerSuite.FeatureData.class); + )); + DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 5cf43fec6f29e..e8f329f9cf29e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -78,7 +78,7 @@ public Vector getExpected() { @Test public void testPCA() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0}), Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 5e8211c2c5118..834fedbb59e1b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -59,7 +60,7 @@ public void polynomialExpansionTest() { .setOutputCol("polyFeatures") .setDegree(3); - JavaRDD data = jsc.parallelize(Lists.newArrayList( + JavaRDD data = jsc.parallelize(Arrays.asList( RowFactory.create( Vectors.dense(-2.0, 2.3), Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 74eb2733f06ef..ed74363f59e34 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -17,9 +17,9 @@ package org.apache.spark.ml.feature; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void standardScaler() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + List points = Arrays.asList( new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 3806f650025b2..02309ce63219a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -54,7 +55,8 @@ public void regexTokenizer() { .setGaps(true) .setMinTokenLength(3); - JavaRDD rdd = jsc.parallelize(Lists.newArrayList( + + JavaRDD rdd = jsc.parallelize(Arrays.asList( new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) )); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index c7ae5468b9429..bfcca62fa1c98 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import java.util.Map; @@ -26,8 +27,6 @@ import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; @@ -52,7 +51,7 @@ public void tearDown() { @Test public void vectorIndexerAPI() { // The tests are to check Java compatibility. 
- List points = Lists.newArrayList( + List points = Arrays.asList( new FeatureData(Vectors.dense(0.0, -2.0)), new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 56988b9fb29cb..f953361427586 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; import org.junit.After; import org.junit.Assert; @@ -63,7 +63,7 @@ public void vectorSlice() { }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) )); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 39c70157f83c0..70f5ad9432212 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.feature; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -50,10 +51,10 @@ public void tearDown() { @Test public void testJavaWord2Vec() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), - RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), - RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), + RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), + RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) )); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) From 75d62307946283b03bec6aaf1bdd4f2b08c93915 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:45:35 +0100 Subject: [PATCH 105/802] [SPARK-10255] [ML] Removes Guava dependencies from spark.ml.param JavaTests Author: Feynman Liang Closes #8446 from feynmanliang/SPARK-10255. 
--- .../java/org/apache/spark/ml/param/JavaParamsSuite.java | 7 ++++--- .../java/org/apache/spark/ml/param/JavaTestParams.java | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index 9890155e9f865..fa777f3d42a9a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.ml.param; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -61,7 +62,7 @@ public void testParamValidate() { ParamValidators.ltEq(1.0); ParamValidators.inRange(0, 1, true, false); ParamValidators.inRange(0, 1); - ParamValidators.inArray(Lists.newArrayList(0, 1, 3)); - ParamValidators.inArray(Lists.newArrayList("a", "b")); + ParamValidators.inArray(Arrays.asList(0, 1, 3)); + ParamValidators.inArray(Arrays.asList("a", "b")); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index dc6ce8061f62b..65841182df9b4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -17,10 +17,9 @@ package org.apache.spark.ml.param; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; - import org.apache.spark.ml.util.Identifiable$; /** @@ -89,7 +88,7 @@ private void init() { myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); - List validStrings = Lists.newArrayList("a", "b"); + List validStrings = Arrays.asList("a", "b"); myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = From 1a446f75b6cac46caea0217a66abeb226946ac71 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 10:46:18 +0100 Subject: [PATCH 106/802] [SPARK-10256] [ML] Removes guava dependency from spark.ml.classification JavaTests Author: Feynman Liang Closes #8447 from feynmanliang/SPARK-10256. 
--- .../apache/spark/ml/classification/JavaNaiveBayesSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index a700c9cddb206..8fd7bf55a2e5d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -18,8 +18,8 @@ package org.apache.spark.ml.classification; import java.io.Serializable; +import java.util.Arrays; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -74,7 +74,7 @@ public void naiveBayesDefaultParams() { @Test public void testNaiveBayes() { - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + JavaRDD jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)), RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)), RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)), From b02e8187225d1765f67ce38864dfaca487be8a44 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Thu, 27 Aug 2015 11:07:37 +0100 Subject: [PATCH 107/802] [SPARK-9613] [HOTFIX] Fix usage of JavaConverters removed in Scala 2.11 Fix for [JavaConverters.asJavaListConverter](http://www.scala-lang.org/api/2.10.5/index.html#scala.collection.JavaConverters$) being removed in 2.11.7 and hence the build fails with the 2.11 profile enabled. Tested with the default 2.10 and 2.11 profiles. BUILD SUCCESS in both cases. Build for 2.10: ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.7.1 -DskipTests clean install and 2.11: ./dev/change-scala-version.sh 2.11 ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.7.1 -Dscala-2.11 -DskipTests clean install Author: Jacek Laskowski Closes #8479 from jaceklaskowski/SPARK-9613-hotfix. --- .../org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 2744e020e9e49..253cabf0133d0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -55,7 +55,7 @@ public void setUp() { double[] xMean = {5.843, 3.057, 3.758, 1.199}; double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = JavaConverters.asJavaListConverter( + List points = JavaConverters.seqAsJavaListConverter( generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) ).asJava(); datasetRDD = jsc.parallelize(points, 2); From e1f4de4a7d15d4ca4b5c64ff929ac3980f5d706f Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 18:46:41 +0100 Subject: [PATCH 108/802] [SPARK-10257] [MLLIB] Removes Guava from all spark.mllib Java tests * Replaces instances of `Lists.newArrayList` with `Arrays.asList` * Replaces `commons.lang.StringUtils` over `com.google.collections.Strings` * Replaces `List` interface over `ArrayList` implementations This PR along with #8445 #8446 #8447 completely removes all `com.google.collections.Lists` dependencies within mllib's Java tests. Author: Feynman Liang Closes #8451 from feynmanliang/SPARK-10257. 
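For readers skimming the many mechanical hunks that follow, the standalone sketch below illustrates why the JDK collections are a drop-in replacement here: `Arrays.asList` builds the fixed-content fixtures the tests need without Guava, and a plain `java.util.ArrayList` covers the spots that previously used `Lists.newArrayListWithCapacity`.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ListFixtures {
  public static void main(String[] args) {
    // Fixed contents: Arrays.asList wraps the array directly, no Guava required.
    // The returned list is fixed-size, which is all a test fixture needs.
    List<String> words = Arrays.asList("this is a sentence".split(" "));

    // Growable list with a capacity hint: java.util.ArrayList replaces
    // Guava's Lists.newArrayListWithCapacity.
    List<Integer> lengths = new ArrayList<Integer>(words.size());
    for (String word : words) {
      lengths.add(word.length());
    }

    System.out.println(words + " -> " + lengths);
  }
}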
--- .../JavaStreamingLogisticRegressionSuite.java | 10 +++---- .../clustering/JavaGaussianMixtureSuite.java | 4 +-- .../mllib/clustering/JavaKMeansSuite.java | 9 +++---- .../clustering/JavaStreamingKMeansSuite.java | 10 +++---- .../spark/mllib/feature/JavaTfIdfSuite.java | 19 +++++++------ .../mllib/feature/JavaWord2VecSuite.java | 6 ++--- .../mllib/fpm/JavaAssociationRulesSuite.java | 5 ++-- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 17 ++++++------ .../spark/mllib/linalg/JavaVectorsSuite.java | 5 ++-- .../mllib/random/JavaRandomRDDsSuite.java | 27 ++++++++++--------- .../mllib/recommendation/JavaALSSuite.java | 5 ++-- .../JavaIsotonicRegressionSuite.java | 7 ++--- .../JavaStreamingLinearRegressionSuite.java | 10 +++---- .../spark/mllib/stat/JavaStatisticsSuite.java | 11 ++++---- 14 files changed, 71 insertions(+), 74 deletions(-) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 55787f8606d48..c9e5ee22f3273 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.classification; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() .setNumIterations(2) .setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 467a7a69e8f30..123f78da54e34 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -18,9 +18,9 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -48,7 +48,7 @@ public void tearDown() { @Test public void runGaussianMixture() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index 31676e64025d0..ad06676c72ac6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import org.junit.After; @@ -25,8 +26,6 @@ import org.junit.Test; import static org.junit.Assert.*; -import com.google.common.collect.Lists; - import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; @@ -48,7 +47,7 @@ public void tearDown() { @Test public void runKMeansUsingStaticMethods() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -67,7 +66,7 @@ public void runKMeansUsingStaticMethods() { @Test public void runKMeansUsingConstructor() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) @@ -90,7 +89,7 @@ public void runKMeansUsingConstructor() { @Test public void testPredictJavaRDD() { - List points = Lists.newArrayList( + List points = Arrays.asList( Vectors.dense(1.0, 2.0, 6.0), Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 3b0e879eec77f..d644766d1e54d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.clustering; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -60,16 +60,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( Vectors.dense(1.0), Vectors.dense(0.0)); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() .setK(1) .setDecayFactor(1.0) diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index fbc26167ce66f..8a320afa4b13d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -18,14 +18,13 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import 
org.junit.Assert; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -50,10 +49,10 @@ public void tfIdf() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(); @@ -70,10 +69,10 @@ public void tfIdfMinimumDocumentFrequency() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("this is a sentence".split(" ")), - Lists.newArrayList("this is another sentence".split(" ")), - Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD> documents = sc.parallelize(Arrays.asList( + Arrays.asList("this is a sentence".split(" ")), + Arrays.asList("this is another sentence".split(" ")), + Arrays.asList("this is still a sentence".split(" "))), 2); JavaRDD termFreqs = tf.transform(documents); termFreqs.collect(); IDF idf = new IDF(2); diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index fb7afe8c6434b..e13ed07e283dd 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.feature; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import com.google.common.base.Strings; import org.junit.After; import org.junit.Assert; @@ -51,8 +51,8 @@ public void tearDown() { public void word2Vec() { // The tests are to check Java compatibility. 
String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); - List words = Lists.newArrayList(sentence.split(" ")); - List> localDoc = Lists.newArrayList(words, words); + List words = Arrays.asList(sentence.split(" ")); + List> localDoc = Arrays.asList(words, words); JavaRDD> doc = sc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index d7c2cb3ae2067..2bef7a8609757 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -17,17 +17,16 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; +import java.util.Arrays; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; - public class JavaAssociationRulesSuite implements Serializable { private transient JavaSparkContext sc; @@ -46,7 +45,7 @@ public void tearDown() { public void runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD> freqItemsets = sc.parallelize(Lists.newArrayList( + JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( new FreqItemset(new String[] {"a"}, 15L), new FreqItemset(new String[] {"b"}, 35L), new FreqItemset(new String[] {"a", "b"}, 12L) diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 9ce2c52dca8b6..154f75d75e4a6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -18,13 +18,12 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; -import com.google.common.collect.Lists; import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; @@ -48,13 +47,13 @@ public void tearDown() { public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Lists.newArrayList( - Lists.newArrayList("r z h k p".split(" ")), - Lists.newArrayList("z y x w v u t s".split(" ")), - Lists.newArrayList("s x o n r".split(" ")), - Lists.newArrayList("x z y m t s q e".split(" ")), - Lists.newArrayList("z".split(" ")), - Lists.newArrayList("x z y r q t p".split(" "))), 2); + JavaRDD> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); FPGrowthModel model = new FPGrowth() .setMinSupport(0.5) diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 1421067dc61ed..77c8c6274f374 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -18,11 +18,10 @@ package org.apache.spark.mllib.linalg; import java.io.Serializable; 
+import java.util.Arrays; import scala.Tuple2; -import com.google.common.collect.Lists; - import org.junit.Test; import static org.junit.Assert.*; @@ -37,7 +36,7 @@ public void denseArrayConstruction() { @Test public void sparseArrayConstruction() { @SuppressWarnings("unchecked") - Vector v = Vectors.sparse(3, Lists.>newArrayList( + Vector v = Vectors.sparse(3, Arrays.asList( new Tuple2(0, 2.0), new Tuple2(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fcc13c00cbdc5..33d81b1e9592b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,7 +17,8 @@ package org.apache.spark.mllib.random; -import com.google.common.collect.Lists; +import java.util.Arrays; + import org.apache.spark.api.java.JavaRDD; import org.junit.Assert; import org.junit.After; @@ -51,7 +52,7 @@ public void testUniformRDD() { JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -64,7 +65,7 @@ public void testNormalRDD() { JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -79,7 +80,7 @@ public void testLNormalRDD() { JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -93,7 +94,7 @@ public void testPoissonRDD() { JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -107,7 +108,7 @@ public void testExponentialRDD() { JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -122,7 +123,7 @@ public void testGammaRDD() { JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -138,7 +139,7 @@ public void testUniformVectorRDD() { JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = 
uniformJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -154,7 +155,7 @@ public void testNormalVectorRDD() { JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -172,7 +173,7 @@ public void testLogNormalVectorRDD() { JavaRDD rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, n); JavaRDD rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); JavaRDD rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -189,7 +190,7 @@ public void testPoissonVectorRDD() { JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -206,7 +207,7 @@ public void testExponentialVectorRDD() { JavaRDD rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); JavaRDD rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); JavaRDD rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -224,7 +225,7 @@ public void testGammaVectorRDD() { JavaRDD rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); JavaRDD rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); JavaRDD rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index af688c504cf1e..271dda4662e0d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -18,12 +18,12 @@ package org.apache.spark.mllib.recommendation; import java.io.Serializable; +import java.util.ArrayList; import java.util.List; import scala.Tuple2; import scala.Tuple3; -import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; import org.junit.After; import org.junit.Assert; @@ -56,8 +56,7 @@ void validatePrediction( double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - List> localUsersProducts = - Lists.newArrayListWithCapacity(users * products); + List> localUsersProducts = new ArrayList(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { localUsersProducts.add(new Tuple2(u, p)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index d38fc91ace3cf..32c2f4f3395b7 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -18,11 +18,12 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import scala.Tuple3; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -36,7 +37,7 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; private List> generateIsotonicInput(double[] labels) { - List> input = Lists.newArrayList(); + ArrayList> input = new ArrayList(labels.length); for (int i = 1; i <= labels.length; i++) { input.add(new Tuple3(labels[i-1], (double) i, 1d)); @@ -77,7 +78,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertTrue(predictions.get(0) == 1d); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index 899c4ea607869..dbf6488d41085 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -18,11 +18,11 @@ package org.apache.spark.mllib.regression; import java.io.Serializable; +import java.util.Arrays; import java.util.List; import scala.Tuple2; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -59,16 +59,16 @@ public void tearDown() { @Test @SuppressWarnings("unchecked") public void javaAPI() { - List trainingBatch = Lists.newArrayList( + List trainingBatch = Arrays.asList( new LabeledPoint(1.0, Vectors.dense(1.0)), new LabeledPoint(0.0, Vectors.dense(0.0))); JavaDStream training = - attachTestInputStream(ssc, Lists.newArrayList(trainingBatch, trainingBatch), 2); - List> testBatch = Lists.newArrayList( + attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); + List> testBatch = Arrays.asList( new Tuple2(10, Vectors.dense(1.0)), new Tuple2(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( - attachTestInputStream(ssc, Lists.newArrayList(testBatch, testBatch), 2)); + attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() .setNumIterations(2) .setInitialWeights(Vectors.dense(0.0)); diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index eb4e3698624bc..4795809e47a46 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -19,7 +19,8 @@ import java.io.Serializable; -import 
com.google.common.collect.Lists; +import java.util.Arrays; + import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -50,8 +51,8 @@ public void tearDown() { @Test public void testCorr() { - JavaRDD x = sc.parallelize(Lists.newArrayList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Lists.newArrayList(1.1, 2.2, 3.1, 4.3)); + JavaRDD x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); Double corr1 = Statistics.corr(x, y); Double corr2 = Statistics.corr(x, y, "pearson"); @@ -61,7 +62,7 @@ public void testCorr() { @Test public void kolmogorovSmirnovTest() { - JavaDoubleRDD data = sc.parallelizeDoubles(Lists.newArrayList(0.2, 1.0, -1.0, 2.0)); + JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( data, "norm", 0.0, 1.0); @@ -69,7 +70,7 @@ public void kolmogorovSmirnovTest() { @Test public void chiSqTest() { - JavaRDD data = sc.parallelize(Lists.newArrayList( + JavaRDD data = sc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); From fdd466bed7a7151dd066d732ef98d225f4acda4a Mon Sep 17 00:00:00 2001 From: Vyacheslav Baranov Date: Thu, 27 Aug 2015 18:56:18 +0100 Subject: [PATCH 109/802] [SPARK-10182] [MLLIB] GeneralizedLinearModel doesn't unpersist cached data `GeneralizedLinearModel` creates a cached RDD when building a model. It's inconvenient, since these RDDs flood the memory when building several models in a row, so useful data might get evicted from the cache. The proposed solution is to always cache the dataset & remove the warning. There's a caveat though: input dataset gets evaluated twice, in line 270 when fitting `StandardScaler` for the first time, and when running optimizer for the second time. So, it might worth to return removed warning. Another possible solution is to disable caching entirely & return removed warning. I don't really know what approach is better. Author: Vyacheslav Baranov Closes #8395 from SlavikBaranov/SPARK-10182. --- .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 7e3b4d5648fe3..8f657bfb9c730 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -359,6 +359,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] + " parent RDDs are also uncached.") } + // Unpersist cached data + if (data.getStorageLevel != StorageLevel.NONE) { + data.unpersist(false) + } + createModel(weights, intercept) } } From dc86a227e4fc8a9d8c3e8c68da8dff9298447fd0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 27 Aug 2015 11:45:15 -0700 Subject: [PATCH 110/802] [SPARK-9148] [SPARK-10252] [SQL] Update SQL Programming Guide Author: Michael Armbrust Closes #8441 from marmbrus/documentation. 
--- docs/sql-programming-guide.md | 92 +++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 19 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e64190b9b209d..99fec6c7785af 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,7 +11,7 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. -For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. +Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. # DataFrames @@ -213,6 +213,11 @@ df.groupBy("age").count().show() // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.DataFrame). + +
@@ -263,6 +268,10 @@ df.groupBy("age").count().show(); // 30 1 {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +
@@ -320,6 +329,10 @@ df.groupBy("age").count().show() {% endhighlight %} +For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). +
@@ -370,10 +383,13 @@ showDF(count(groupBy(df, "age"))) {% endhighlight %} -
+For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). + +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html).
+ ## Running SQL Queries Programmatically @@ -870,12 +886,11 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Thus, it is not safe to have multiple writers attempting to write to the same location. -Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the new data. - + @@ -1671,12 +1686,12 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. +Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL +will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). -Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` -and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive -jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the -version specified by users. An isolated classloader is used here to avoid dependency conflicts. +The following options can be used to configure the version of Hive that is used to retrieve metadata:
-<tr><th>Scala/Java</th><th>Python</th><th>Meaning</th></tr>
+<tr><th>Scala/Java</th><th>Any Language</th><th>Meaning</th></tr>
 <tr>
   <td><code>SaveMode.ErrorIfExists</code> (default)</td>
   <td><code>"error"</code> (default)</td>
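For illustration only (not from the patch), a small Scala sketch of how the save modes above are selected; `df` and the output path are placeholders.

{% highlight scala %}
import org.apache.spark.sql.SaveMode

// Default behavior: fail if data already exists at the target path.
df.write.mode(SaveMode.ErrorIfExists).parquet("namesAndAges.parquet")
// Overwrite: existing data is deleted before the new data is written out.
df.write.mode(SaveMode.Overwrite).parquet("namesAndAges.parquet")
{% endhighlight %}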
@@ -1685,7 +1700,7 @@ version specified by users. An isolated classloader is used here to avoid depend @@ -1696,12 +1711,16 @@ version specified by users. An isolated classloader is used here to avoid depend property can be one of three options:
   1. builtin
-     Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is
+     Use Hive 1.2.1, which is bundled with the Spark assembly jar when -Phive is
      enabled. When this option is chosen, spark.sql.hive.metastore.version must be
-     either 0.13.1 or not defined.
+     either 1.2.1 or not defined.
   2. maven
-     Use Hive jars of specified version downloaded from Maven repositories.
-  3. A classpath in the standard format for both Hive and Hadoop.
+     Use Hive jars of specified version downloaded from Maven repositories. This configuration
+     is not generally recommended for production deployments.
+  3. A classpath in the standard format for the JVM. This classpath must include all of Hive
+     and its dependencies, including the correct version of Hadoop. These jars only need to be
+     present on the driver, but if you are running in yarn cluster mode then you must ensure
+     they are packaged with your application.
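A hedged configuration sketch for the options listed above; the application name and the chosen metastore version are illustrative assumptions, and only the two property names come from the guide.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// Talk to a Hive 0.13.1 metastore while executing against the built-in Hive 1.2.1 classes.
val conf = new SparkConf()
  .setAppName("HiveMetastoreVersionExample")            // placeholder name
  .set("spark.sql.hive.metastore.version", "0.13.1")    // version of the remote metastore
  .set("spark.sql.hive.metastore.jars", "maven")        // download matching client jars from Maven
val sc = new SparkContext(conf)
val sqlContext = new HiveContext(sc)
sqlContext.sql("SHOW TABLES").show()
{% endhighlight %}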
@@ -2017,6 +2036,28 @@ options. # Migration Guide +## Upgrading From Spark SQL 1.4 to 1.5 + + - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with + code generation for expression evaluation. These features can both be disabled by setting + `spark.sql.tungsten.enabled` to `false. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.parquet.mergeSchema` to `true`. + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + - In-memory columnar storage partition pruning is on by default. It can be disabled by setting + `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. + - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum + precision of 38. When inferring schema from `BigDecimal` objects, a precision of (38, 18) is now + used. When no precision is specified in DDL then the default remains `Decimal(10, 0)`. + - Timestamps are now stored at a precision of 1us, rather than 1ns + - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains + unchanged. + - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). + - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe + and thus this output committer will not be used when speculation is on, independent of configuration. + ## Upgrading from Spark SQL 1.3 to 1.4 #### DataFrame data reader/writer interface @@ -2038,7 +2079,8 @@ See the API docs for `SQLContext.read` ( #### DataFrame.groupBy retains grouping columns -Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`. +Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the +grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
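For illustration only (not from the patch), a Scala sketch of the `groupBy` change described above; `df` with `department` and `age` columns and `sqlContext` are placeholders.

{% highlight scala %}
import org.apache.spark.sql.functions._

// In 1.4+ the grouping column is retained: result columns are department, max(age).
df.groupBy("department").agg(max("age")).show()

// Restore the 1.3 behavior (aggregate columns only).
sqlContext.setConf("spark.sql.retainGroupColumns", "false")
df.groupBy("department").agg(max("age")).show()
{% endhighlight %}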
@@ -2175,7 +2217,7 @@ Python UDF registration is unchanged. When using DataTypes in Python you will need to construct them (i.e. `StringType()`) instead of referencing a singleton. -## Migration Guide for Shark User +## Migration Guide for Shark Users ### Scheduling To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, @@ -2251,6 +2293,7 @@ Spark SQL supports the vast majority of Hive features, such as: * User defined functions (UDF) * User defined aggregation functions (UDAF) * User defined serialization formats (SerDes) +* Window functions * Joins * `JOIN` * `{LEFT|RIGHT|FULL} OUTER JOIN` @@ -2261,7 +2304,7 @@ Spark SQL supports the vast majority of Hive features, such as: * `SELECT col FROM ( SELECT a + b AS col from t1) t2` * Sampling * Explain -* Partitioned tables +* Partitioned tables including dynamic partition insertion * View * All Hive DDL Functions, including: * `CREATE TABLE` @@ -2323,8 +2366,9 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. +# Reference -# Data Types +## Data Types Spark SQL and DataFrames support the following data types: @@ -2937,3 +2981,13 @@ from pyspark.sql.types import *
+## NaN Semantics + +There is specially handling for not-a-number (NaN) when dealing with `float` or `double` types that +does not exactly match standard floating point semantics. +Specifically: + + - NaN = NaN returns true. + - In aggregations all NaN values are grouped together. + - NaN is treated as a normal value in join keys. + - NaN values go last when in ascending order, larger than any other numeric value. From 84baa5e9b5edc8c55871fbed5057324450bf097f Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 27 Aug 2015 20:19:09 +0100 Subject: [PATCH 111/802] [SPARK-10315] remove document on spark.akka.failure-detector.threshold https://issues.apache.org/jira/browse/SPARK-10315 this parameter is not used any longer and there is some mistake in the current document , should be 'akka.remote.watch-failure-detector.threshold' Author: CodingCat Closes #8483 from CodingCat/SPARK_10315. --- docs/configuration.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 4a6e4dd05b661..77c5cbc7b3196 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -906,16 +906,6 @@ Apart from these, the following properties are also available, and may be useful #### Networking
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
 <tr>
   <td><code>spark.sql.hive.metastore.version</code></td>
   <td><code>0.13.1</code></td>
   <td>
     Version of the Hive metastore. Available
-    options are 0.12.0 and 0.13.1. Support for more versions is coming in the future.
+    options are 0.12.0 through 1.2.1.
   </td>
 </tr>
- - - - - From 6185cdd2afcd492b77ff225b477b3624e3bc7bb2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 27 Aug 2015 13:57:20 -0700 Subject: [PATCH 112/802] [SPARK-9901] User guide for RowMatrix Tall-and-skinny QR jira: https://issues.apache.org/jira/browse/SPARK-9901 The jira covers only the document update. I can further provide example code for QR (like the ones for SVD and PCA) in a separate PR. Author: Yuhao Yang Closes #8462 from hhbyyh/qrDoc. --- docs/mllib-data-types.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index f0e8d5495675d..065bf4727624f 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -337,7 +337,10 @@ limited by the integer range but it should be much smaller in practice.
A [`RowMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) can be -created from an `RDD[Vector]` instance. Then we can compute its column summary statistics. +created from an `RDD[Vector]` instance. Then we can compute its column summary statistics and decompositions. +[QR decomposition](https://en.wikipedia.org/wiki/QR_decomposition) is of the form A = QR where Q is an orthogonal matrix and R is an upper triangular matrix. +For [singular value decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition) and [principal component analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis), please refer to [Dimensionality reduction](mllib-dimensionality-reduction.html). + {% highlight scala %} import org.apache.spark.mllib.linalg.Vector @@ -350,6 +353,9 @@ val mat: RowMatrix = new RowMatrix(rows) // Get its size. val m = mat.numRows() val n = mat.numCols() + +// QR decomposition +val qrResult = mat.tallSkinnyQR(true) {% endhighlight %}
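As a hedged follow-up to the Scala example above (it reuses the `qrResult` value and is not part of the patch), the returned `QRDecomposition` exposes the two factors directly:

{% highlight scala %}
// Q has orthonormal columns and the same number of rows as mat; R is the n x n upper triangular factor.
val q = qrResult.Q
val r = qrResult.R
println(r)
{% endhighlight %}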
@@ -370,6 +376,9 @@ RowMatrix mat = new RowMatrix(rows.rdd()); // Get its size. long m = mat.numRows(); long n = mat.numCols(); + +// QR decomposition +QRDecomposition result = mat.tallSkinnyQR(true); {% endhighlight %} From c94ecdfc5b3c0fe6c38a170dc2af9259354dc9e3 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 27 Aug 2015 15:33:43 -0700 Subject: [PATCH 113/802] [SPARK-9906] [ML] User guide for LogisticRegressionSummary User guide for LogisticRegression summaries Author: MechCoder Author: Manoj Kumar Author: Feynman Liang Closes #8197 from MechCoder/log_summary_user_guide. --- docs/ml-linear-methods.md | 149 ++++++++++++++++++++++++++++++++++---- 1 file changed, 133 insertions(+), 16 deletions(-) diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 1ac83d94c9e81..2761aeb789621 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -23,20 +23,41 @@ displayTitle: ML - Linear Methods \]` -In MLlib, we implement popular linear methods such as logistic regression and linear least squares with L1 or L2 regularization. Refer to [the linear methods in mllib](mllib-linear-methods.html) for details. In `spark.ml`, we also include Pipelines API for [Elastic net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid of L1 and L2 regularization proposed in [this paper](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically it is defined as a linear combination of the L1-norm and the L2-norm: +In MLlib, we implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods in mllib](mllib-linear-methods.html) for +details. In `spark.ml`, we also include Pipelines API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: `\[ -\alpha \|\wv\|_1 + (1-\alpha) \frac{1}{2}\|\wv\|_2^2, \alpha \in [0, 1]. +\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. \]` -By setting $\alpha$ properly, it contains both L1 and L2 regularization as special cases. For example, if a [linear regression](https://en.wikipedia.org/wiki/Linear_regression) model is trained with the elastic net parameter $\alpha$ set to $1$, it is equivalent to a [Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. On the other hand, if $\alpha$ is set to $0$, the trained model reduces to a [ridge regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. We implement Pipelines API for both linear regression and logistic regression with elastic net regularization. - -**Examples** +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. 
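For illustration only (not from the patch), a Scala sketch of the two special cases just described; the regularization strength 0.1 is an arbitrary placeholder.

{% highlight scala %}
import org.apache.spark.ml.regression.LinearRegression

// elasticNetParam = 1.0 gives a lasso (pure L1) model, 0.0 gives a ridge (pure L2) model.
val lasso = new LinearRegression().setElasticNetParam(1.0).setRegParam(0.1)
val ridge = new LinearRegression().setElasticNetParam(0.0).setRegParam(0.1)
{% endhighlight %}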
+We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. + +## Example: Logistic Regression + +The following example shows how to train a logistic regression model +with elastic net regularization. `elasticNetParam` corresponds to +$\alpha$ and `regParam` corresponds to $\lambda$.
- {% highlight scala %} - import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.mllib.util.MLUtils @@ -53,15 +74,11 @@ val lrModel = lr.fit(training) // Print the weights and intercept for logistic regression println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") - {% endhighlight %} -
- {% highlight java %} - import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.mllib.regression.LabeledPoint; @@ -99,9 +116,7 @@ public class LogisticRegressionWithElasticNetExample {
- {% highlight python %} - from pyspark.ml.classification import LogisticRegression from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import MLUtils @@ -118,12 +133,114 @@ lrModel = lr.fit(training) print("Weights: " + str(lrModel.weights)) print("Intercept: " + str(lrModel.intercept)) {% endhighlight %} +
+The `spark.ml` implementation of logistic regression also supports
+extracting a summary of the model over the training set. Note that the
+predictions and metrics which are stored as `DataFrame` in
+`BinaryLogisticRegressionSummary` are annotated `@transient` and hence
+only available on the driver.
+
+ +
+ +[`LogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary) +provides a summary for a +[`LogisticRegressionModel`](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight scala %} +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +val trainingSummary = lrModel.summary + +// Obtain the loss per iteration. +val objectiveHistory = trainingSummary.objectiveHistory +objectiveHistory.foreach(loss => println(loss)) + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +val roc = binarySummary.roc +roc.show() +roc.select("FPR").show() +println(binarySummary.areaUnderROC) + +// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with +// this selected threshold. +val fMeasure = binarySummary.fMeasureByThreshold +val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) +val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). + select("threshold").head().getDouble(0) +logReg.setThreshold(bestThreshold) +logReg.fit(logRegDataFrame) +{% endhighlight %}
-### Optimization +
+[`LogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html) +provides a summary for a +[`LogisticRegressionModel`](api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html). +Currently, only binary classification is supported and the +summary must be explicitly cast to +[`BinaryLogisticRegressionTrainingSummary`](api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html). +This will likely change when multiclass classification is supported. + +Continuing the earlier example: + +{% highlight java %} +// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example +LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); + +// Obtain the loss per iteration. +double[] objectiveHistory = trainingSummary.objectiveHistory(); +for (double lossPerIteration : objectiveHistory) { + System.out.println(lossPerIteration); +} + +// Obtain the metrics useful to judge performance on test data. +// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a +// binary classification problem. +BinaryLogisticRegressionSummary binarySummary = (BinaryLogisticRegressionSummary) trainingSummary; + +// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. +DataFrame roc = binarySummary.roc(); +roc.show(); +roc.select("FPR").show(); +System.out.println(binarySummary.areaUnderROC()); + +// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with +// this selected threshold. +DataFrame fMeasure = binarySummary.fMeasureByThreshold(); +double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); +double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). + select("threshold").head().getDouble(0); +logReg.setThreshold(bestThreshold); +logReg.fit(logRegDataFrame); +{% endhighlight %} +
+ +
+Logistic regression model summary is not yet supported in Python. +
+ +
+ +# Optimization + +The optimization algorithm underlying the implementation is called +[Orthant-Wise Limited-memory +QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) +(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 +regularization and elastic net. -The optimization algorithm underlies the implementation is called [Orthant-Wise Limited-memory QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf) -(OWL-QN). It is an extension of L-BFGS that can effectively handle L1 regularization and elastic net. From 5bfe9e1111d9862084586549a7dc79476f67bab9 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 16:10:37 -0700 Subject: [PATCH 114/802] [SPARK-9680] [MLLIB] [DOC] StopWordsRemovers user guide and Java compatibility test * Adds user guide for ml.feature.StopWordsRemovers, ran code examples on my machine * Cleans up scaladocs for public methods * Adds test for Java compatibility * Follow up Python user guide code example is tracked by SPARK-10249 Author: Feynman Liang Closes #8436 from feynmanliang/SPARK-10230. --- docs/ml-features.md | 102 +++++++++++++++++- .../ml/feature/JavaStopWordsRemoverSuite.java | 72 +++++++++++++ 2 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java diff --git a/docs/ml-features.md b/docs/ml-features.md index 62de4838981cb..89a9bad570446 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -306,15 +306,111 @@ regexTokenizer = RegexTokenizer(inputCol="sentence", outputCol="words", pattern= +## StopWordsRemover +[Stop words](https://en.wikipedia.org/wiki/Stop_words) are words which +should be excluded from the input, typically because the words appear +frequently and don't carry as much meaning. + +`StopWordsRemover` takes as input a sequence of strings (e.g. the output +of a [Tokenizer](ml-features.html#tokenizer)) and drops all the stop +words from the input sequences. The list of stopwords is specified by +the `stopWords` parameter. We provide [a list of stop +words](http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words) by +default, accessible by calling `getStopWords` on a newly instantiated +`StopWordsRemover` instance. -## $n$-gram +**Examples** -An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. +Assume that we have the following DataFrame with columns `id` and `raw`: -`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +~~~~ + id | raw +----|---------- + 0 | [I, saw, the, red, baloon] + 1 | [Mary, had, a, little, lamb] +~~~~ + +Applying `StopWordsRemover` with `raw` as the input column and `filtered` as the output +column, we should get the following: + +~~~~ + id | raw | filtered +----|-----------------------------|-------------------- + 0 | [I, saw, the, red, baloon] | [saw, red, baloon] + 1 | [Mary, had, a, little, lamb]|[Mary, little, lamb] +~~~~ + +In `filtered`, the stop words "I", "the", "had", and "a" have been +filtered out.
+
+ +[`StopWordsRemover`](api/scala/index.html#org.apache.spark.ml.feature.StopWordsRemover) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight scala %} +import org.apache.spark.ml.feature.StopWordsRemover + +val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") +val dataSet = sqlContext.createDataFrame(Seq( + (0, Seq("I", "saw", "the", "red", "baloon")), + (1, Seq("Mary", "had", "a", "little", "lamb")) +)).toDF("id", "raw") + +remover.transform(dataSet).show() +{% endhighlight %} +
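A possible follow-up sketch (not from the patch) that reuses the `dataSet` from the example above and overrides the default stop word list:

{% highlight scala %}
// Supply a custom stop word list and make matching case sensitive.
val customRemover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(Array("saw", "Mary"))   // replaces the default English list
  .setCaseSensitive(true)

customRemover.transform(dataSet).show()
{% endhighlight %}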
+ +
+ +[`StopWordsRemover`](api/java/org/apache/spark/ml/feature/StopWordsRemover.html) +takes an input column name, an output column name, a list of stop words, +and a boolean indicating if the matches should be case sensitive (false +by default). + +{% highlight java %} +import java.util.Arrays; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.StopWordsRemover; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + +JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) +)); +StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) +}); +DataFrame dataset = jsql.createDataFrame(rdd, schema); + +remover.transform(dataset).show(); +{% endhighlight %} +
+
+ +## $n$-gram + +An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (typically words) for some integer $n$. The `NGram` class can be used to transform input features into $n$-grams. + +`NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +
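Since the paragraph above describes `NGram` only in prose, here is a hedged Scala sketch (not from the patch); `sqlContext` and the toy sentence are assumptions.

{% highlight scala %}
import org.apache.spark.ml.feature.NGram

// Produce bigrams (n = 2) from a tokenized column.
val wordDataFrame = sqlContext.createDataFrame(Seq(
  (0, Seq("Hi", "I", "heard", "about", "Spark"))
)).toDF("id", "words")

val ngram = new NGram().setN(2).setInputCol("words").setOutputCol("ngrams")
ngram.transform(wordDataFrame).select("ngrams").show()
{% endhighlight %}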
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java new file mode 100644 index 0000000000000..76cdd0fae84ab --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + + +public class JavaStopWordsRemoverSuite { + + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + StopWordsRemover remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered"); + + JavaRDD rdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), + RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) + )); + StructType schema = new StructType(new StructField[] { + new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(rdd, schema); + + remover.transform(dataset).collect(); + } +} From b3dd569ad40905f8861a547a1e25ed3ca8e1d272 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 27 Aug 2015 16:11:25 -0700 Subject: [PATCH 115/802] [SPARK-10287] [SQL] Fixes JSONRelation refreshing on read path https://issues.apache.org/jira/browse/SPARK-10287 After porting json to HadoopFsRelation, it seems hard to keep the behavior of picking up new files automatically for JSON. This PR removes this behavior, so JSON is consistent with others (ORC and Parquet). Author: Yin Huai Closes #8469 from yhuai/jsonRefresh. 
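A hedged sketch of the `REFRESH TABLE` workaround this change documents below; the table name `json_table` is a placeholder and `sqlContext` is assumed to be a `HiveContext`.

{% highlight scala %}
// Pick up JSON files written by other applications after this change.
sqlContext.sql("REFRESH TABLE json_table")
// or, equivalently, through the API:
sqlContext.refreshTable("json_table")
{% endhighlight %}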
--- docs/sql-programming-guide.md | 6 ++++++ .../execution/datasources/json/JSONRelation.scala | 9 --------- .../org/apache/spark/sql/sources/interfaces.scala | 2 +- .../apache/spark/sql/sources/InsertSuite.scala | 15 --------------- 4 files changed, 7 insertions(+), 25 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 99fec6c7785af..e8eb88488ee24 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2057,6 +2057,12 @@ options. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe and thus this output committer will not be used when speculation is on, independent of configuration. + - JSON data source will not automatically load new files that are created by other applications + (i.e. files that are not inserted to the dataset through Spark SQL). + For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), + users can use `REFRESH TABLE` SQL command or `HiveContext`'s `refreshTable` method + to include those new files to the table. For a DataFrame representing a JSON dataset, users need to recreate + the DataFrame and the new DataFrame will include new files. ## Upgrading from Spark SQL 1.3 to 1.4 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 114c8b211891e..ab8ca5f748f24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -111,15 +111,6 @@ private[sql] class JSONRelation( jsonSchema } - override private[sql] def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - refresh() - super.buildScan(requiredColumns, filters, inputPaths, broadcastedConf) - } - override def buildScan( requiredColumns: Array[String], filters: Array[Filter], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index b3b326fe612c7..dff726b33fc74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -562,7 +562,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio }) } - private[sql] def buildScan( + final private[sql] def buildScan( requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 78bd3e5582964..084d83f6e9bff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -167,21 +167,6 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { ) } - test("save directly to the path of a JSON table") { - caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b") - .write.mode(SaveMode.Overwrite).json(path.toString) - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i * 5, s"str$i")) - ) - - 
caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString) - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, s"str$i")) - ) - } - test("it is not allowed to write to a table while querying it.") { val message = intercept[AnalysisException] { sql( From 54cda0deb6bebf1470f16ba5bcc6c4fb842bdac1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 27 Aug 2015 16:38:00 -0700 Subject: [PATCH 116/802] [SPARK-10321] sizeInBytes in HadoopFsRelation Having sizeInBytes in HadoopFsRelation to enable broadcast join. cc marmbrus Author: Davies Liu Closes #8490 from davies/sizeInByte. --- .../main/scala/org/apache/spark/sql/sources/interfaces.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index dff726b33fc74..7b030b7d73bd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -518,6 +518,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio override def inputFiles: Array[String] = cachedLeafStatuses().map(_.getPath.toString).toArray + override def sizeInBytes: Long = cachedLeafStatuses().map(_.getLen).sum + /** * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically * discovered. Note that they should always be nullable. From 1f90c5e2198bcf49e115d97ec300c17c1be4dcb4 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 27 Aug 2015 19:38:53 -0700 Subject: [PATCH 117/802] [SPARK-8505] [SPARKR] Add settings to kick `lint-r` from `./dev/run-test.py` JoshRosen we'd like to check the SparkR source code with the `dev/lint-r` script on the Jenkins. I tried to incorporate the script into `dev/run-test.py`. Could you review it when you have time? shivaram I modified `dev/lint-r` and `dev/lint-r.R` to install lintr package into a local directory(`R/lib/`) and to exit with a lint status. Could you review it? - [[SPARK-8505] Add settings to kick `lint-r` from `./dev/run-test.py` - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8505) Author: Yu ISHIKAWA Closes #7883 from yu-iskw/SPARK-8505. --- dev/lint-r | 11 +++++++++++ dev/lint-r.R | 12 +++++++----- dev/run-tests-codes.sh | 13 +++++++------ dev/run-tests-jenkins | 2 ++ dev/run-tests.py | 21 ++++++++++++++++++++- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/dev/lint-r b/dev/lint-r index 7d5f4cd31153d..c15d57aad86da 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -28,3 +28,14 @@ if ! type "Rscript" > /dev/null; then fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" + +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME"` +if [ "$NUM_LINES" = "0" ] ; then + lint_status=0 + echo "lintr checks passed." +else + lint_status=1 + echo "lintr checks failed." +fi + +exit "$lint_status" diff --git a/dev/lint-r.R b/dev/lint-r.R index 48bd6246096ae..999eef571b824 100644 --- a/dev/lint-r.R +++ b/dev/lint-r.R @@ -17,8 +17,14 @@ argv <- commandArgs(TRUE) SPARK_ROOT_DIR <- as.character(argv[1]) +LOCAL_LIB_LOC <- file.path(SPARK_ROOT_DIR, "R", "lib") -# Installs lintr from Github. +# Checks if SparkR is installed in a local directory. +if (! 
library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} + +# Installs lintr from Github in a local directory. # NOTE: The CRAN's version is too old to adapt to our rules. if ("lintr" %in% row.names(installed.packages()) == FALSE) { devtools::install_github("jimhester/lintr") @@ -27,9 +33,5 @@ if ("lintr" %in% row.names(installed.packages()) == FALSE) { library(lintr) library(methods) library(testthat) -if (! library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { - stop("You should install SparkR in a local directory with `R/install-dev.sh`.") -} - path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh index f4b238e1b78a7..1f16790522e76 100644 --- a/dev/run-tests-codes.sh +++ b/dev/run-tests-codes.sh @@ -21,9 +21,10 @@ readonly BLOCK_GENERAL=10 readonly BLOCK_RAT=11 readonly BLOCK_SCALA_STYLE=12 readonly BLOCK_PYTHON_STYLE=13 -readonly BLOCK_DOCUMENTATION=14 -readonly BLOCK_BUILD=15 -readonly BLOCK_MIMA=16 -readonly BLOCK_SPARK_UNIT_TESTS=17 -readonly BLOCK_PYSPARK_UNIT_TESTS=18 -readonly BLOCK_SPARKR_UNIT_TESTS=19 +readonly BLOCK_R_STYLE=14 +readonly BLOCK_DOCUMENTATION=15 +readonly BLOCK_BUILD=16 +readonly BLOCK_MIMA=17 +readonly BLOCK_SPARK_UNIT_TESTS=18 +readonly BLOCK_PYSPARK_UNIT_TESTS=19 +readonly BLOCK_SPARKR_UNIT_TESTS=20 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index f144c053046c5..39cf54f78104c 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -210,6 +210,8 @@ done failing_test="Scala style tests" elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_R_STYLE" ]; then + failing_test="R style tests" elif [ "$test_result" -eq "$BLOCK_DOCUMENTATION" ]; then failing_test="to generate documentation" elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py index f689425ee40b6..4fd703a7c219f 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -209,6 +209,18 @@ def run_python_style_checks(): run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) +def run_sparkr_style_checks(): + set_title_and_block("Running R style checks", "BLOCK_R_STYLE") + + if which("R"): + # R style check should be executed after `install-dev.sh`. + # Since warnings about `no visible global function definition` appear + # without the installation. SEE ALSO: SPARK-9121. 
+ run_cmd([os.path.join(SPARK_HOME, "dev", "lint-r")]) + else: + print("Ignoring SparkR style check as R was not found in PATH") + + def build_spark_documentation(): set_title_and_block("Building Spark Documentation", "BLOCK_DOCUMENTATION") os.environ["PRODUCTION"] = "1 jekyll build" @@ -387,7 +399,6 @@ def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") if which("R"): - run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: print("Ignoring SparkR tests as R was not found in PATH") @@ -438,6 +449,12 @@ def main(): if java_version.minor < 8: print("[warn] Java 8 tests will not run because JDK version is < 1.8.") + # install SparkR + if which("R"): + run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) + else: + print("Can't install SparkR as R is was not found in PATH") + if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables # to reflect the environment settings @@ -485,6 +502,8 @@ def main(): run_scala_style_checks() if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() + if not changed_files or any(f.endswith(".R") for f in changed_files): + run_sparkr_style_checks() # determine if docs were changed and if we're inside the amplab environment # note - the below commented out until *all* Jenkins workers can get `jekyll` installed From 30734d45fbbb269437c062241a9161e198805a76 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 27 Aug 2015 21:44:06 -0700 Subject: [PATCH 118/802] [SPARK-9911] [DOC] [ML] Update Userguide for Evaluator I added a small note about the different types of evaluator and the metrics used. Author: MechCoder Closes #8304 from MechCoder/multiclass_evaluator. --- docs/ml-guide.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index de8fead3529e4..01bf5ee18e328 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -643,6 +643,13 @@ An important task in ML is *model selection*, or using data to find the best mod Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. + +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) +for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric +method in each of these evaluators. + The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. 
`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. @@ -708,9 +715,12 @@ val pipeline = new Pipeline() // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. // This will allow us to jointly choose parameters for all Pipeline stages. // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric +// used is areaUnderROC. val crossval = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator) + // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -831,9 +841,12 @@ Pipeline pipeline = new Pipeline() // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. // This will allow us to jointly choose parameters for all Pipeline stages. // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric +// used is areaUnderROC. CrossValidator crossval = new CrossValidator() .setEstimator(pipeline) .setEvaluator(new BinaryClassificationEvaluator()); + // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. From af0e1249b1c881c0fa7a921fd21fd2c27214b980 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 27 Aug 2015 21:55:20 -0700 Subject: [PATCH 119/802] [SPARK-9905] [ML] [DOC] Adds LinearRegressionSummary user guide * Adds user guide for `LinearRegressionSummary` * Fixes unresolved issues in #8197 CC jkbradley mengxr Author: Feynman Liang Closes #8491 from feynmanliang/SPARK-9905. --- docs/ml-linear-methods.md | 140 ++++++++++++++++++++++++++++++++++---- 1 file changed, 127 insertions(+), 13 deletions(-) diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index 2761aeb789621..cdd9d4999fa1b 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -34,7 +34,7 @@ net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). Mathematically, it is defined as a convex combination of the $L_1$ and the $L_2$ regularization terms: `\[ -\alpha~\lambda \|\wv\|_1 + (1-\alpha) \frac{\lambda}{2}\|\wv\|_2^2, \alpha \in [0, 1], \lambda \geq 0. +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 \]` By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ regularization as special cases. 
For example, if a [linear @@ -95,7 +95,7 @@ public class LogisticRegressionWithElasticNetExample { SparkContext sc = new SparkContext(conf); SQLContext sql = new SQLContext(sc); - String path = "sample_libsvm_data.txt"; + String path = "data/mllib/sample_libsvm_data.txt"; // Load training data DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); @@ -103,7 +103,7 @@ public class LogisticRegressionWithElasticNetExample { LogisticRegression lr = new LogisticRegression() .setMaxIter(10) .setRegParam(0.3) - .setElasticNetParam(0.8) + .setElasticNetParam(0.8); // Fit the model LogisticRegressionModel lrModel = lr.fit(training); @@ -158,10 +158,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight scala %} +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example val trainingSummary = lrModel.summary -// Obtain the loss per iteration. +// Obtain the objective per iteration. val objectiveHistory = trainingSummary.objectiveHistory objectiveHistory.foreach(loss => println(loss)) @@ -173,17 +175,14 @@ val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. val roc = binarySummary.roc roc.show() -roc.select("FPR").show() println(binarySummary.areaUnderROC) -// Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with -// this selected threshold. +// Set the model threshold to maximize F-Measure val fMeasure = binarySummary.fMeasureByThreshold val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure). select("threshold").head().getDouble(0) -logReg.setThreshold(bestThreshold) -logReg.fit(logRegDataFrame) +lrModel.setThreshold(bestThreshold) {% endhighlight %}
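The Scala summary example above relies on two imports that the patch does not show; as an assumption about the intended setup, they would look like this:

{% highlight scala %}
import org.apache.spark.sql.functions.max   // for max("F-Measure")
import sqlContext.implicits._               // for the $"F-Measure" column syntax
{% endhighlight %}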
@@ -199,8 +198,12 @@ This will likely change when multiclass classification is supported. Continuing the earlier example: {% highlight java %} +import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary; +import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary; +import org.apache.spark.sql.functions; + // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier example -LogisticRegressionTrainingSummary trainingSummary = logRegModel.summary(); +LogisticRegressionTrainingSummary trainingSummary = lrModel.summary(); // Obtain the loss per iteration. double[] objectiveHistory = trainingSummary.objectiveHistory(); @@ -222,20 +225,131 @@ System.out.println(binarySummary.areaUnderROC()); // Get the threshold corresponding to the maximum F-Measure and rerun LogisticRegression with // this selected threshold. DataFrame fMeasure = binarySummary.fMeasureByThreshold(); -double maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0); +double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0); double bestThreshold = fMeasure.where(fMeasure.col("F-Measure").equalTo(maxFMeasure)). select("threshold").head().getDouble(0); -logReg.setThreshold(bestThreshold); -logReg.fit(logRegDataFrame); +lrModel.setThreshold(bestThreshold); {% endhighlight %}
+
Logistic regression model summary is not yet supported in Python.
+## Example: Linear Regression + +The interface for working with linear regression models and model +summaries is similar to the logistic regression case. The following +example demonstrates training an elastic net regularized linear +regression model and extracting model summary statistics. + +
+ +
+{% highlight scala %} +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.MLUtils + +// Load training data +val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +val lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8) + +// Fit the model +val lrModel = lr.fit(training) + +// Print the weights and intercept for linear regression +println(s"Weights: ${lrModel.weights} Intercept: ${lrModel.intercept}") + +// Summarize the model over the training set and print out some metrics +val trainingSummary = lrModel.summary +println(s"numIterations: ${trainingSummary.totalIterations}") +println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}") +trainingSummary.residuals.show() +println(s"RMSE: ${trainingSummary.rootMeanSquaredError}") +println(s"r2: ${trainingSummary.r2}") +{% endhighlight %} +
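As a possible follow-up to the Scala example above (not from the patch; it assumes `lrModel` and `training` are still in scope), the fitted model can be applied with `transform`:

{% highlight scala %}
// Score the training data and inspect a few predictions.
lrModel.transform(training).select("features", "label", "prediction").show(5)
{% endhighlight %}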
+ +
+{% highlight java %} +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.regression.LinearRegressionModel; +import org.apache.spark.ml.regression.LinearRegressionTrainingSummary; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class LinearRegressionWithElasticNetExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf() + .setAppName("Linear Regression with Elastic Net Example"); + + SparkContext sc = new SparkContext(conf); + SQLContext sql = new SQLContext(sc); + String path = "data/mllib/sample_libsvm_data.txt"; + + // Load training data + DataFrame training = sql.createDataFrame(MLUtils.loadLibSVMFile(sc, path).toJavaRDD(), LabeledPoint.class); + + LinearRegression lr = new LinearRegression() + .setMaxIter(10) + .setRegParam(0.3) + .setElasticNetParam(0.8); + + // Fit the model + LinearRegressionModel lrModel = lr.fit(training); + + // Print the weights and intercept for linear regression + System.out.println("Weights: " + lrModel.weights() + " Intercept: " + lrModel.intercept()); + + // Summarize the model over the training set and print out some metrics + LinearRegressionTrainingSummary trainingSummary = lrModel.summary(); + System.out.println("numIterations: " + trainingSummary.totalIterations()); + System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory())); + trainingSummary.residuals().show(); + System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError()); + System.out.println("r2: " + trainingSummary.r2()); + } +} +{% endhighlight %} +
+ +
+ +{% highlight python %} +from pyspark.ml.regression import LinearRegression +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Load training data +training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + +lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + +# Fit the model +lrModel = lr.fit(training) + +# Print the weights and intercept for linear regression +print("Weights: " + str(lrModel.weights)) +print("Intercept: " + str(lrModel.intercept)) + +# Linear regression model summary is not yet supported in Python. +{% endhighlight %} +
+ +
+ # Optimization The optimization algorithm underlying the implementation is called From 89b943438512fcfb239c268b43431397de46cbcf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 27 Aug 2015 22:30:01 -0700 Subject: [PATCH 120/802] [SPARK-SQL] [MINOR] Fixes some typos in HiveContext Author: Cheng Lian Closes #8481 from liancheng/hive-context-typo. --- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 8 ++++---- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0a458fa9ab8d..2e791cea96b41 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -171,11 +171,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. * - allow SQL11 keywords to be used as identifiers */ - private[sql] def defaultOverides() = { + private[sql] def defaultOverrides() = { setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") } - defaultOverides() + defaultOverrides() /** * The copy of the Hive client that is used to retrieve metadata from the Hive MetaStore. @@ -190,8 +190,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // into the isolated client loader val metadataConf = new HiveConf() - val defaltWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") - logInfo("defalt warehouse location is " + defaltWarehouseLocation) + val defaultWarehouseLocation = metadataConf.get("hive.metastore.warehouse.dir") + logInfo("default warehouse location is " + defaultWarehouseLocation) // `configure` goes second to override other settings. val allConfig = metadataConf.asScala.map(e => e.getKey -> e.getValue).toMap ++ configure diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 572eaebe81ac2..57fea5d8db343 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -434,7 +434,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } - defaultOverides() + defaultOverrides() runSqlHive("USE default") From 7583681e6b0824d7eed471dc4d8fa0b2addf9ffc Mon Sep 17 00:00:00 2001 From: noelsmith Date: Thu, 27 Aug 2015 23:59:30 -0700 Subject: [PATCH 121/802] [SPARK-10188] [PYSPARK] Pyspark CrossValidator with RMSE selects incorrect model * Added isLargerBetter() method to Pyspark Evaluator to match the Scala version. * JavaEvaluator delegates isLargerBetter() to underlying Scala object. * Added check for isLargerBetter() in CrossValidator to determine whether to use argmin or argmax. * Added test cases for where smaller is better (RMSE) and larger is better (R-Squared). (This contribution is my original work and that I license the work to the project under Sparks' open source license) Author: noelsmith Closes #8399 from noel-smith/pyspark-rmse-xval-fix. 
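To make the selection rule described in this commit message concrete, here is a hedged Scala sketch of the same idea (illustrative only; the actual fix is in the Python diff below):

{% highlight scala %}
// Choose argmax when the evaluator's metric should be maximized (e.g. areaUnderROC, r2)
// and argmin when it should be minimized (e.g. RMSE).
def bestIndex(metrics: Array[Double], isLargerBetter: Boolean): Int =
  if (isLargerBetter) metrics.indexOf(metrics.max) else metrics.indexOf(metrics.min)
{% endhighlight %}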
--- python/pyspark/ml/evaluation.py | 12 +++++ python/pyspark/ml/tests.py | 87 +++++++++++++++++++++++++++++++++ python/pyspark/ml/tuning.py | 6 ++- 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 6b0a9ffde9f42..cb3b07947e488 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -66,6 +66,14 @@ def evaluate(self, dataset, params=None): else: raise ValueError("Params must be a param map but got %s." % type(params)) + def isLargerBetter(self): + """ + Indicates whether the metric returned by :py:meth:`evaluate` should be maximized + (True, default) or minimized (False). + A given evaluator may support multiple metrics which may be maximized or minimized. + """ + return True + @inherit_doc class JavaEvaluator(Evaluator, JavaWrapper): @@ -85,6 +93,10 @@ def _evaluate(self, dataset): self._transfer_params_to_java() return self._java_obj.evaluate(dataset._jdf) + def isLargerBetter(self): + self._transfer_params_to_java() + return self._java_obj.isLargerBetter() + @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index c151d21fd661a..60e4237293adc 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -32,11 +32,14 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame, SQLContext +from pyspark.sql.functions import rand +from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.util import keyword_only from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.feature import * +from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel from pyspark.mllib.linalg import DenseVector @@ -264,5 +267,89 @@ def test_ngram(self): self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) +class HasInducedError(Params): + + def __init__(self): + super(HasInducedError, self).__init__() + self.inducedError = Param(self, "inducedError", + "Uniformly-distributed error added to feature") + + def getInducedError(self): + return self.getOrDefault(self.inducedError) + + +class InducedErrorModel(Model, HasInducedError): + + def __init__(self): + super(InducedErrorModel, self).__init__() + + def _transform(self, dataset): + return dataset.withColumn("prediction", + dataset.feature + (rand(0) * self.getInducedError())) + + +class InducedErrorEstimator(Estimator, HasInducedError): + + def __init__(self, inducedError=1.0): + super(InducedErrorEstimator, self).__init__() + self._set(inducedError=inducedError) + + def _fit(self, dataset): + model = InducedErrorModel() + self._copyValues(model) + return model + + +class CrossValidatorTests(PySparkTestCase): + + def test_fit_minimize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="rmse") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = 
evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + + def test_fit_maximize_metric(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) + + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = (ParamGridBuilder() + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) + .build()) + cv = CrossValidator(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + bestModel = cvModel.bestModel + bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + + self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), + "Best model should have zero induced error") + self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index dcfee6a3170ab..cae778869e9c5 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -223,7 +223,11 @@ def _fit(self, dataset): # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) metrics[j] += metric - bestIndex = np.argmax(metrics) + + if eva.isLargerBetter(): + bestIndex = np.argmax(metrics) + else: + bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) From 2f99c37273c1d82e2ba39476e4429ea4aaba7ec6 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Fri, 28 Aug 2015 00:37:50 -0700 Subject: [PATCH 122/802] [SPARK-10328] [SPARKR] Fix generic for na.omit S3 function is at https://stat.ethz.ch/R-manual/R-patched/library/stats/html/na.fail.html Author: Shivaram Venkataraman Author: Shivaram Venkataraman Author: Yu ISHIKAWA Closes #8495 from shivaram/na-omit-fix. --- R/pkg/R/DataFrame.R | 6 +++--- R/pkg/R/generics.R | 2 +- R/pkg/inst/tests/test_sparkSQL.R | 23 ++++++++++++++++++++++- dev/lint-r | 2 +- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index dd8126aebf467..74de7c81e35a6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1699,9 +1699,9 @@ setMethod("dropna", #' @name na.omit #' @export setMethod("na.omit", - signature(x = "DataFrame"), - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { - dropna(x, how, minNonNulls, cols) + signature(object = "DataFrame"), + function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(object, how, minNonNulls, cols) }) #' fillna diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index a829d46c1894c..b578b8789d2c5 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -413,7 +413,7 @@ setGeneric("dropna", #' @rdname nafunctions #' @export setGeneric("na.omit", - function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + function(object, ...) 
{ standardGeneric("na.omit") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 4b672e115f924..933b11c8ee7e2 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1083,7 +1083,7 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats2)[5, "age"], "30") }) -test_that("dropna() on a DataFrame", { +test_that("dropna() and na.omit() on a DataFrame", { df <- jsonFile(sqlContext, jsonPathNa) rows <- collect(df) @@ -1092,6 +1092,8 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = "name")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) @@ -1101,48 +1103,67 @@ test_that("dropna() on a DataFrame", { expect_identical(expected$age, actual$age) expect_identical(expected$height, actual$height) expect_identical(expected$name, actual$name) + actual <- collect(na.omit(df, cols = "age")) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) expect_identical(expected, actual) + actual <- collect(na.omit(df)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) expect_identical(expected, actual) + actual <- collect(na.omit(df, "all")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) expect_identical(expected, actual) + actual <- collect(na.omit(df, "any")) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, "any", cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, "all", cols = c("age", "height"))) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 2, cols = c("age", "height"))) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) expect_identical(expected, actual) + actual <- collect(na.omit(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { diff --git a/dev/lint-r b/dev/lint-r index 
c15d57aad86da..bfda0bca15eb7 100755 --- a/dev/lint-r +++ b/dev/lint-r @@ -29,7 +29,7 @@ fi `which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" -NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME"` +NUM_LINES=`wc -l < "$LINT_R_REPORT_FILE_NAME" | awk '{print $1}'` if [ "$NUM_LINES" = "0" ] ; then lint_status=0 echo "lintr checks passed." From 4eeda8d472498acd40ef54723d1be9924a273d76 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 28 Aug 2015 00:50:26 -0700 Subject: [PATCH 123/802] [SPARK-10260] [ML] Add @Since annotation to ml.clustering ### JIRA [[SPARK-10260] Add Since annotation to ml.clustering - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-10260) Author: Yu ISHIKAWA Closes #8455 from yu-iskw/SPARK-10260. --- .../apache/spark/ml/clustering/KMeans.scala | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 47a18cdb31b53..f40ab71fb22a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.clustering -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} @@ -39,9 +39,11 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * Set the number of clusters to create (k). Must be > 1. Default: 2. * @group param */ + @Since("1.5.0") final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) /** @group getParam */ + @Since("1.5.0") def getK: Int = $(k) /** @@ -50,10 +52,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||. * @group expertParam */ + @Since("1.5.0") final val initMode = new Param[String](this, "initMode", "initialization algorithm", (value: String) => MLlibKMeans.validateInitMode(value)) /** @group expertGetParam */ + @Since("1.5.0") def getInitMode: String = $(initMode) /** @@ -61,10 +65,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5. * @group expertParam */ + @Since("1.5.0") final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", (value: Int) => value > 0) /** @group expertGetParam */ + @Since("1.5.0") def getInitSteps: Int = $(initSteps) /** @@ -84,27 +90,32 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * * @param parentModel a model trained by spark.mllib.clustering.KMeans. 
*/ +@Since("1.5.0") @Experimental class KMeansModel private[ml] ( - override val uid: String, + @Since("1.5.0") override val uid: String, private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams { + @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { val copied = new KMeansModel(uid, parentModel) copyValues(copied, extra) } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + @Since("1.5.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters } @@ -114,8 +125,11 @@ class KMeansModel private[ml] ( * * @see [[http://dx.doi.org/10.14778/2180912.2180915 Bahmani et al., Scalable k-means++.]] */ +@Since("1.5.0") @Experimental -class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams { +class KMeans @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) + extends Estimator[KMeansModel] with KMeansParams { setDefault( k -> 2, @@ -124,34 +138,45 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean initSteps -> 5, tol -> 1e-4) + @Since("1.5.0") override def copy(extra: ParamMap): KMeans = defaultCopy(extra) + @Since("1.5.0") def this() = this(Identifiable.randomUID("kmeans")) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) /** @group setParam */ + @Since("1.5.0") def setK(value: Int): this.type = set(k, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitMode(value: String): this.type = set(initMode, value) /** @group expertSetParam */ + @Since("1.5.0") def setInitSteps(value: Int): this.type = set(initSteps, value) /** @group setParam */ + @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ + @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** @group setParam */ + @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point } @@ -167,6 +192,7 @@ class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMean copyValues(model) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } From cc39803062119c1d14611dc227b9ed0ed1284d38 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 28 Aug 2015 09:32:23 +0100 Subject: [PATCH 124/802] [SPARK-10295] [CORE] Dynamic allocation in Mesos does not release when RDDs are cached Remove obsolete warning about dynamic allocation not working with cached RDDs See discussion in https://issues.apache.org/jira/browse/SPARK-10295 Author: Sean Owen Closes #8489 from srowen/SPARK-10295. 
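For context only, not taken from the patch: the removed warning fired whenever an application had dynamic allocation enabled and then cached an RDD. A minimal sketch of that kind of configuration, with a made-up application name, looks like the following; the two `spark.*` keys are the standard dynamic-allocation settings, and benefiting from them still requires a cluster manager with the external shuffle service running.

```python
# Minimal sketch, not from this patch: the configuration the removed warning
# used to complain about (dynamic allocation enabled in a job that caches RDDs).
# The application name is a placeholder; submit with a real cluster manager.
from pyspark import SparkConf

conf = (SparkConf()
        .setAppName("dynamic-allocation-with-cached-rdd")
        .set("spark.dynamicAllocation.enabled", "true")
        .set("spark.shuffle.service.enabled", "true"))  # required by dynamic allocation

print(conf.toDebugString())
# In the driver one would then cache data as usual, for example
#     rdd = sc.parallelize(range(1000)).cache()
# and after this patch persistRDD() no longer logs a warning about it.
```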
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f3da04a7f55d0..738887076b0d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1590,11 +1590,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { - _executorAllocationManager.foreach { _ => - logWarning( - s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + - s"${rdd.id} will be lost when executors are removed.") - } persistentRdds(rdd.id) = rdd } From 18294cd8710427076caa86bfac596de67089d57e Mon Sep 17 00:00:00 2001 From: Keiji Yoshida Date: Fri, 28 Aug 2015 09:36:50 +0100 Subject: [PATCH 125/802] Fix DynamodDB/DynamoDB typo in Kinesis Integration doc Fix DynamodDB/DynamoDB typo in Kinesis Integration doc Author: Keiji Yoshida Closes #8501 from yosssi/patch-1. --- docs/streaming-kinesis-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index a7bcaec6fcd84..238a911a9199f 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -91,7 +91,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - Kinesis data processing is ordered per partition and occurs at-least once per message. - - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. - A single Kinesis stream shard is processed by one input DStream at a time. From 71a077f6c16c8816eae13341f645ba50d997f63d Mon Sep 17 00:00:00 2001 From: Dharmesh Kakadia Date: Fri, 28 Aug 2015 09:38:35 +0100 Subject: [PATCH 126/802] typo in comment Author: Dharmesh Kakadia Closes #8497 from dharmeshkakadia/patch-2. --- .../apache/spark/network/shuffle/protocol/RegisterExecutor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index cca8b17c4f129..167ef33104227 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -27,7 +27,7 @@ /** * Initial registration message between an executor and its local shuffle server. - * Returns nothing (empty bye array). + * Returns nothing (empty byte array). */ public class RegisterExecutor extends BlockTransferMessage { public final String appId; From 1502a0f6c5d2f85a331b29d3bf50002911ea393e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 28 Aug 2015 09:32:52 -0500 Subject: [PATCH 127/802] [YARN] [MINOR] Avoid hard code port number in YarnShuffleService test Current port number is fixed as default (7337) in test, this will introduce port contention exception, better to change to a random number in unit test. 
squito , seems you're author of this unit test, mind taking a look at this fix? Thanks a lot. ``` [info] - executor state kept across NM restart *** FAILED *** (597 milliseconds) [info] org.apache.hadoop.service.ServiceStateException: java.net.BindException: Address already in use [info] at org.apache.hadoop.service.ServiceStateException.convert(ServiceStateException.java:59) [info] at org.apache.hadoop.service.AbstractService.init(AbstractService.java:172) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply$mcV$sp(YarnShuffleServiceSuite.scala:72) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply(YarnShuffleServiceSuite.scala:70) [info] at org.apache.spark.network.yarn.YarnShuffleServiceSuite$$anonfun$1.apply(YarnShuffleServiceSuite.scala:70) [info] at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) [info] at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) [info] at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) [info] at org.scalatest.Transformer.apply(Transformer.scala:22) [info] at org.scalatest.Transformer.apply(Transformer.scala:20) [info] at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) [info] at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) ... ``` Author: jerryshao Closes #8502 from jerryshao/avoid-hardcode-port. --- .../org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index 2f22cbdbeac37..6aa8c814cd4f0 100644 --- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -37,6 +37,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), classOf[YarnShuffleService].getCanonicalName) + yarnConfig.setInt("spark.shuffle.service.port", 0) yarnConfig.get("yarn.nodemanager.local-dirs").split(",").foreach { dir => val d = new File(dir) From e2a843090cb031f6aa774f6d9c031a7f0f732ee1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 28 Aug 2015 08:00:44 -0700 Subject: [PATCH 128/802] [SPARK-9890] [DOC] [ML] User guide for CountVectorizer jira: https://issues.apache.org/jira/browse/SPARK-9890 document with Scala and java examples Author: Yuhao Yang Closes #8487 from hhbyyh/cvDoc. --- docs/ml-features.md | 109 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/docs/ml-features.md b/docs/ml-features.md index 89a9bad570446..90654d1e5a248 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -211,6 +211,115 @@ for feature in result.select("result").take(3): +## CountVectorizer + +`CountVectorizer` and `CountVectorizerModel` aim to help convert a collection of text documents + to vectors of token counts. When an a-priori dictionary is not available, `CountVectorizer` can + be used as an `Estimator` to extract the vocabulary and generates a `CountVectorizerModel`. The + model produces sparse representations for the documents over the vocabulary, which can then be + passed to other algorithms like LDA. 
+ + During the fitting process, `CountVectorizer` will select the top `vocabSize` words ordered by + term frequency across the corpus. An optional parameter "minDF" also affect the fitting process + by specifying the minimum number (or fraction if < 1.0) of documents a term must appear in to be + included in the vocabulary. + +**Examples** + +Assume that we have the following DataFrame with columns `id` and `texts`: + +~~~~ + id | texts +----|---------- + 0 | Array("a", "b", "c") + 1 | Array("a", "b", "b", "c", "a") +~~~~ + +each row in`texts` is a document of type Array[String]. +Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c), +then the output column "vector" after transformation contains: + +~~~~ + id | texts | vector +----|---------------------------------|--------------- + 0 | Array("a", "b", "c") | (3,[0,1,2],[1.0,1.0,1.0]) + 1 | Array("a", "b", "b", "c", "a") | (3,[0,1,2],[2.0,2.0,1.0]) +~~~~ + +each vector represents the token counts of the document over the vocabulary. + +
+<div data-lang="scala" markdown="1">
+More details can be found in the API docs for +[CountVectorizer](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizer) and +[CountVectorizerModel](api/scala/index.html#org.apache.spark.ml.feature.CountVectorizerModel). +{% highlight scala %} +import org.apache.spark.ml.feature.CountVectorizer +import org.apache.spark.mllib.util.CountVectorizerModel + +val df = sqlContext.createDataFrame(Seq( + (0, Array("a", "b", "c")), + (1, Array("a", "b", "b", "c", "a")) +)).toDF("id", "words") + +// fit a CountVectorizerModel from the corpus +val cvModel: CountVectorizerModel = new CountVectorizer() + .setInputCol("words") + .setOutputCol("features") + .setVocabSize(3) + .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary + .fit(df) + +// alternatively, define CountVectorizerModel with a-priori vocabulary +val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features") + +cvModel.transform(df).select("features").show() +{% endhighlight %} +
+<div data-lang="java" markdown="1">
+More details can be found in the API docs for +[CountVectorizer](api/java/org/apache/spark/ml/feature/CountVectorizer.html) and +[CountVectorizerModel](api/java/org/apache/spark/ml/feature/CountVectorizerModel.html). +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.feature.CountVectorizer; +import org.apache.spark.ml.feature.CountVectorizerModel; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.*; + +// Input data: Each row is a bag of words from a sentence or document. +JavaRDD jrdd = jsc.parallelize(Arrays.asList( + RowFactory.create(Arrays.asList("a", "b", "c")), + RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) +)); +StructType schema = new StructType(new StructField [] { + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) +}); +DataFrame df = sqlContext.createDataFrame(jrdd, schema); + +// fit a CountVectorizerModel from the corpus +CountVectorizerModel cvModel = new CountVectorizer() + .setInputCol("text") + .setOutputCol("feature") + .setVocabSize(3) + .setMinDF(2) // a term must appear in more or equal to 2 documents to be included in the vocabulary + .fit(df); + +// alternatively, define CountVectorizerModel with a-priori vocabulary +CountVectorizerModel cvm = new CountVectorizerModel(new String[]{"a", "b", "c"}) + .setInputCol("text") + .setOutputCol("feature"); + +cvModel.transform(df).show(); +{% endhighlight %} +
+</div>
+ # Feature Transformers ## Tokenizer From 499e8e154bdcc9d7b2f685b159e0ddb4eae48fe4 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Fri, 28 Aug 2015 09:13:21 -0700 Subject: [PATCH 129/802] [SPARK-8952] [SPARKR] - Wrap normalizePath calls with suppressWarnings This is based on davies comment on SPARK-8952 which suggests to only call normalizePath() when path starts with '~' Author: Luciano Resende Closes #8343 from lresende/SPARK-8952. --- R/pkg/R/SQLContext.R | 4 ++-- R/pkg/R/sparkR.R | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 110117a18ccbc..1bc6445311473 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -201,7 +201,7 @@ setMethod("toDF", signature(x = "RDD"), jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - path <- normalizePath(path) + path <- suppressWarnings(normalizePath(path)) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") sdf <- callJMethod(sqlContext, "jsonFile", path) @@ -251,7 +251,7 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), normalizePath) + paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index e83104f116422..3c57a44db257d 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -160,7 +160,7 @@ sparkR.init <- function( }) if (nchar(sparkHome) != 0) { - sparkHome <- normalizePath(sparkHome) + sparkHome <- suppressWarnings(normalizePath(sparkHome)) } sparkEnvirMap <- new.env() From d3f87dc39480f075170817bbd00142967a938078 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 28 Aug 2015 11:51:42 -0700 Subject: [PATCH 130/802] [SPARK-10325] Override hashCode() for public Row This commit fixes an issue where the public SQL `Row` class did not override `hashCode`, causing it to violate the hashCode() + equals() contract. To fix this, I simply ported the `hashCode` implementation from the 1.4.x version of `Row`. Author: Josh Rosen Closes #8500 from JoshRosen/SPARK-10325 and squashes the following commits: 51ffea1 [Josh Rosen] Override hashCode() for public Row. --- .../src/main/scala/org/apache/spark/sql/Row.scala | 13 +++++++++++++ .../test/scala/org/apache/spark/sql/RowSuite.scala | 9 +++++++++ 2 files changed, 22 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index cfd9cb0e62598..ed2fdf9f2f7cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow @@ -410,6 +411,18 @@ trait Row extends Serializable { true } + override def hashCode: Int = { + // Using Scala's Seq hash code implementation. 
+ var n = 0 + var h = MurmurHash3.seqSeed + val len = length + while (n < len) { + h = MurmurHash3.mix(h, apply(n).##) + n += 1 + } + MurmurHash3.finalizeHash(h, n) + } + /* ---------------------- utility methods for Scala ---------------------- */ /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 795d4e983f27e..77ccd6f775e50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -85,4 +85,13 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { val r2 = Row(Double.NaN) assert(r1 === r2) } + + test("equals and hashCode") { + val r1 = Row("Hello") + val r2 = Row("Hello") + assert(r1 === r2) + assert(r1.hashCode() === r2.hashCode()) + val r3 = Row("World") + assert(r3.hashCode() != r1.hashCode()) + } } From c53c902fa9c458200245f919067b41dde9cd9418 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 28 Aug 2015 12:33:40 -0700 Subject: [PATCH 131/802] [SPARK-9284] [TESTS] Allow all tests to run without an assembly. This change aims at speeding up the dev cycle a little bit, by making sure that all tests behave the same w.r.t. where the code to be tested is loaded from. Namely, that means that tests don't rely on the assembly anymore, rather loading all needed classes from the build directories. The main change is to make sure all build directories (classes and test-classes) are added to the classpath of child processes when running tests. YarnClusterSuite required some custom code since the executors are run differently (i.e. not through the launcher library, like standalone and Mesos do). I also found a couple of tests that could leak a SparkContext on failure, and added code to handle those. With this patch, it's possible to run the following command from a clean source directory and have all tests pass: mvn -Pyarn -Phadoop-2.4 -Phive-thriftserver install Author: Marcelo Vanzin Closes #7629 from vanzin/SPARK-9284. --- bin/spark-class | 16 ++++---- .../spark/launcher/SparkLauncherSuite.java | 0 .../spark/broadcast/BroadcastSuite.scala | 10 ++++- .../launcher/AbstractCommandBuilder.java | 28 ++++++++------ pom.xml | 8 ++++ project/SparkBuild.scala | 2 + .../deploy/yarn/BaseYarnClusterSuite.scala | 37 +++++++++++++------ .../spark/deploy/yarn/YarnClusterSuite.scala | 20 ++++++++-- .../yarn/YarnShuffleIntegrationSuite.scala | 2 +- .../spark/launcher/TestClasspathBuilder.scala | 36 ++++++++++++++++++ 10 files changed, 122 insertions(+), 37 deletions(-) rename {launcher => core}/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java (100%) create mode 100644 yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala diff --git a/bin/spark-class b/bin/spark-class index 2b59e5df5736f..e38e08dec40e4 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -43,17 +43,19 @@ else fi num_jars="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" | wc -l)" -if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" ]; then +if [ "$num_jars" -eq "0" -a -z "$SPARK_ASSEMBLY_JAR" -a "$SPARK_PREPEND_CLASSES" != "1" ]; then echo "Failed to find Spark assembly in $ASSEMBLY_DIR." 1>&2 echo "You need to build Spark before running this program." 
1>&2 exit 1 fi -ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" -if [ "$num_jars" -gt "1" ]; then - echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 - echo "$ASSEMBLY_JARS" 1>&2 - echo "Please remove all but one jar." 1>&2 - exit 1 +if [ -d "$ASSEMBLY_DIR" ]; then + ASSEMBLY_JARS="$(ls -1 "$ASSEMBLY_DIR" | grep "^spark-assembly.*hadoop.*\.jar$" || true)" + if [ "$num_jars" -gt "1" ]; then + echo "Found multiple Spark assembly jars in $ASSEMBLY_DIR:" 1>&2 + echo "$ASSEMBLY_JARS" 1>&2 + echo "Please remove all but one jar." 1>&2 + exit 1 + fi fi SPARK_ASSEMBLY_JAR="${ASSEMBLY_DIR}/${ASSEMBLY_JARS}" diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java similarity index 100% rename from launcher/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java rename to core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 48e74f06f79b1..fb7a8ae3f9d41 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -310,8 +310,14 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val _sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) - _sc + try { + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc + } catch { + case e: Throwable => + _sc.stop() + throw e + } } else { new SparkContext("local", "test", broadcastConf) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 5e793a5c48775..0a237ee73b670 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -169,9 +169,11 @@ List buildClassPath(String appClassPath) throws IOException { "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher"); if (prependClasses) { - System.err.println( - "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + - "assembly."); + if (!isTesting) { + System.err.println( + "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of " + + "assembly."); + } for (String project : projects) { addToClassPath(cp, String.format("%s/%s/target/scala-%s/classes", sparkHome, project, scala)); @@ -200,7 +202,7 @@ List buildClassPath(String appClassPath) throws IOException { // For the user code case, we fall back to looking for the Spark assembly under SPARK_HOME. // That duplicates some of the code in the shell scripts that look for the assembly, though. 
String assembly = getenv(ENV_SPARK_ASSEMBLY); - if (assembly == null && isEmpty(getenv("SPARK_TESTING"))) { + if (assembly == null && !isTesting) { assembly = findAssembly(); } addToClassPath(cp, assembly); @@ -215,12 +217,14 @@ List buildClassPath(String appClassPath) throws IOException { libdir = new File(sparkHome, "lib_managed/jars"); } - checkState(libdir.isDirectory(), "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); + if (libdir.isDirectory()) { + for (File jar : libdir.listFiles()) { + if (jar.getName().startsWith("datanucleus-")) { + addToClassPath(cp, jar.getAbsolutePath()); + } } + } else { + checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } addToClassPath(cp, getenv("HADOOP_CONF_DIR")); @@ -256,15 +260,15 @@ String getScalaVersion() { return scala; } String sparkHome = getSparkHome(); - File scala210 = new File(sparkHome, "assembly/target/scala-2.10"); - File scala211 = new File(sparkHome, "assembly/target/scala-2.11"); + File scala210 = new File(sparkHome, "launcher/target/scala-2.10"); + File scala211 = new File(sparkHome, "launcher/target/scala-2.11"); checkState(!scala210.isDirectory() || !scala211.isDirectory(), "Presence of build for both scala versions (2.10 and 2.11) detected.\n" + "Either clean one of them or set SPARK_SCALA_VERSION in your environment."); if (scala210.isDirectory()) { return "2.10"; } else { - checkState(scala211.isDirectory(), "Cannot find any assembly build directories."); + checkState(scala211.isDirectory(), "Cannot find any build directories."); return "2.11"; } } diff --git a/pom.xml b/pom.xml index 0716016523ee1..88ebceca769e9 100644 --- a/pom.xml +++ b/pom.xml @@ -1421,6 +1421,10 @@ org.apache.thrift libthrift + + org.mortbay.jetty + servlet-api + com.google.guava guava @@ -1892,6 +1896,8 @@ launched by the tests have access to the correct test-time classpath. --> ${test_classpath} + 1 + 1 ${test.java.home} @@ -1929,6 +1935,8 @@ launched by the tests have access to the correct test-time classpath. 
--> ${test_classpath} + 1 + 1 ${test.java.home} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ea52bfd67944a..901cfa538d23e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -547,6 +547,8 @@ object TestSettings { envVars in Test ++= Map( "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), + "SPARK_PREPEND_CLASSES" -> "1", + "SPARK_TESTING" -> "1", "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", javaOptions in Test += "-Dspark.test.home=" + sparkHome, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index b4f8049bff577..17c59ff06e0c1 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.yarn.server.MiniYARNCluster import org.scalatest.{BeforeAndAfterAll, Matchers} import org.apache.spark._ +import org.apache.spark.launcher.TestClasspathBuilder import org.apache.spark.util.Utils abstract class BaseYarnClusterSuite @@ -43,6 +44,9 @@ abstract class BaseYarnClusterSuite |log4j.appender.console.target=System.err |log4j.appender.console.layout=org.apache.log4j.PatternLayout |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + |log4j.logger.org.apache.hadoop=WARN + |log4j.logger.org.eclipse.jetty=WARN + |log4j.logger.org.spark-project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ @@ -51,8 +55,7 @@ abstract class BaseYarnClusterSuite private var hadoopConfDir: File = _ private var logConfDir: File = _ - - def yarnConfig: YarnConfiguration + def newYarnConfig(): YarnConfiguration override def beforeAll() { super.beforeAll() @@ -65,8 +68,14 @@ abstract class BaseYarnClusterSuite val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, UTF_8) + // Disable the disk utilization check to avoid the test hanging when people's disks are + // getting full. 
+ val yarnConf = newYarnConfig() + yarnConf.set("yarn.nodemanager.disk-health-checker.max-disk-utilization-per-disk-percentage", + "100.0") + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) - yarnCluster.init(yarnConfig) + yarnCluster.init(yarnConf) yarnCluster.start() // There's a race in MiniYARNCluster in which start() may return before the RM has updated @@ -114,19 +123,23 @@ abstract class BaseYarnClusterSuite sparkArgs: Seq[String] = Nil, extraClassPath: Seq[String] = Nil, extraJars: Seq[String] = Nil, - extraConf: Map[String, String] = Map()): Unit = { + extraConf: Map[String, String] = Map(), + extraEnv: Map[String, String] = Map()): Unit = { val master = if (clientMode) "yarn-client" else "yarn-cluster" val props = new Properties() props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath()) - val childClasspath = logConfDir.getAbsolutePath() + - File.pathSeparator + - sys.props("java.class.path") + - File.pathSeparator + - extraClassPath.mkString(File.pathSeparator) - props.setProperty("spark.driver.extraClassPath", childClasspath) - props.setProperty("spark.executor.extraClassPath", childClasspath) + val testClasspath = new TestClasspathBuilder() + .buildClassPath( + logConfDir.getAbsolutePath() + + File.pathSeparator + + extraClassPath.mkString(File.pathSeparator)) + .asScala + .mkString(File.pathSeparator) + + props.setProperty("spark.driver.extraClassPath", testClasspath) + props.setProperty("spark.executor.extraClassPath", testClasspath) // SPARK-4267: make sure java options are propagated correctly. props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"") @@ -168,7 +181,7 @@ abstract class BaseYarnClusterSuite appArgs Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()) ++ extraEnv) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 5a4ea2ea2f4ff..b5a42fd6afd98 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -28,7 +28,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers import org.apache.spark._ -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, SparkListenerExecutorAdded} +import org.apache.spark.launcher.TestClasspathBuilder +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, + SparkListenerExecutorAdded} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils @@ -39,7 +41,7 @@ import org.apache.spark.util.Utils */ class YarnClusterSuite extends BaseYarnClusterSuite { - override def yarnConfig: YarnConfiguration = new YarnConfiguration() + override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() private val TEST_PYFILE = """ |import mod1, mod2 @@ -111,6 +113,17 @@ class YarnClusterSuite extends BaseYarnClusterSuite { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + // When running tests, let's not assume the user has built the assembly module, which also + // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the + // needed locations. 
+ val sparkHome = sys.props("spark.test.home"); + val pythonPath = Seq( + s"$sparkHome/python/lib/py4j-0.8.2.1-src.zip", + s"$sparkHome/python") + val extraEnv = Map( + "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), + "PYTHONPATH" -> pythonPath.mkString(File.pathSeparator)) + val moduleDir = if (clientMode) { // In client-mode, .py files added with --py-files are not visible in the driver. @@ -130,7 +143,8 @@ class YarnClusterSuite extends BaseYarnClusterSuite { runSpark(clientMode, primaryPyFile.getAbsolutePath(), sparkArgs = Seq("--py-files", pyFiles), - appArgs = Seq(result.getAbsolutePath())) + appArgs = Seq(result.getAbsolutePath()), + extraEnv = extraEnv) checkResult(result) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 5e8238822b90a..8d9c9b3004eda 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} */ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { - override def yarnConfig: YarnConfiguration = { + override def newYarnConfig(): YarnConfiguration = { val yarnConfig = new YarnConfiguration() yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle") yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"), diff --git a/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala new file mode 100644 index 0000000000000..da9e8e21a26ae --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/launcher/TestClasspathBuilder.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.launcher + +import java.util.{List => JList, Map => JMap} + +/** + * Exposes AbstractCommandBuilder to the YARN tests, so that they can build classpaths the same + * way other cluster managers do. + */ +private[spark] class TestClasspathBuilder extends AbstractCommandBuilder { + + childEnv.put(CommandBuilderUtils.ENV_SPARK_HOME, sys.props("spark.test.home")) + + override def buildClassPath(extraCp: String): JList[String] = super.buildClassPath(extraCp) + + /** Not used by the YARN tests. 
*/ + override def buildCommand(env: JMap[String, String]): JList[String] = + throw new UnsupportedOperationException() + +} From 45723214e694b9a440723e9504c562e6393709f3 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Fri, 28 Aug 2015 13:09:13 -0700 Subject: [PATCH 132/802] [SPARK-10336][example] fix not being able to set intercept in LR example `fitIntercept` is a command line option but not set in the main program. dbtsai Author: Shuo Xiang Closes #8510 from coderxiang/intercept and squashes the following commits: 57c9b7d [Shuo Xiang] fix not being able to set intercept in LR example --- .../org/apache/spark/examples/ml/LogisticRegressionExample.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index 7682557127b51..8e3760ddb50a9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -136,6 +136,7 @@ object LogisticRegressionExample { .setElasticNetParam(params.elasticNetParam) .setMaxIter(params.maxIter) .setTol(params.tol) + .setFitIntercept(params.fitIntercept) stages += lor val pipeline = new Pipeline().setStages(stages.toArray) From 88032ecaf0455886aed7a66b30af80dae7f6cff7 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 28 Aug 2015 13:53:31 -0700 Subject: [PATCH 133/802] [SPARK-9671] [MLLIB] re-org user guide and add migration guide This PR updates the MLlib user guide and adds migration guide for 1.4->1.5. * merge migration guide for `spark.mllib` and `spark.ml` packages * remove dependency section from `spark.ml` guide * move the paragraph about `spark.mllib` and `spark.ml` to the top and recommend `spark.ml` * move Sam's talk to footnote to make the section focus on dependencies Minor changes to code examples and other wording will be in a separate PR. jkbradley srowen feynmanliang Author: Xiangrui Meng Closes #8498 from mengxr/SPARK-9671. --- docs/ml-guide.md | 52 ++------------ docs/mllib-guide.md | 119 ++++++++++++++++----------------- docs/mllib-migration-guides.md | 30 +++++++++ 3 files changed, 95 insertions(+), 106 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 01bf5ee18e328..ce53400b6ee56 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -21,19 +21,11 @@ title: Spark ML Programming Guide \]` -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. - -See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of `spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. - +The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical +machine learning pipelines. 
+See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of +`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. **Table of Contents** @@ -171,7 +163,7 @@ This is useful if there are two algorithms with the `maxIter` parameter in a `Pi # Algorithm Guides -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. +There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. **Pipelines API Algorithm Guides** @@ -880,35 +872,3 @@ jsc.stop(); - -# Dependencies - -Spark ML currently depends on MLlib and has the same dependencies. -Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info. - -Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies. - -# Migration Guide - -## From 1.3 to 1.4 - -Several major API changes occurred, including: -* `Param` and other APIs for specifying parameters -* `uid` unique IDs for Pipeline components -* Reorganization of certain classes -Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. - -However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. - -## From 1.2 to 1.3 - -The main API changes are from Spark SQL. We list the most important changes here: - -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. -* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. -* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. - -Other changes were in `LogisticRegression`: - -* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). -* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. 
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 6330c977552d1..876dcfd40ed7b 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -5,21 +5,28 @@ displayTitle: Machine Learning Library (MLlib) Guide description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT --- -MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, -including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives. -Guides for individual algorithms are listed below. +MLlib is Spark's machine learning (ML) library. +Its goal is to make practical machine learning scalable and easy. +It consists of common learning algorithms and utilities, including classification, regression, +clustering, collaborative filtering, dimensionality reduction, as well as lower-level optimization +primitives and higher-level pipeline APIs. -The API is divided into 2 parts: +It divides into two packages: -* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. -* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. +* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API + built on top of RDDs. +* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API + built on top of DataFrames for constructing ML pipelines. -We list major functionality from both below, with links to detailed guides. +Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. +But we will keep supporting `spark.mllib` along with the development of `spark.ml`. +Users should be comfortable using `spark.mllib` features and expect more features coming. +Developers should contribute new algorithms to `spark.ml` if they fit the ML pipeline concept well, +e.g., feature extractors and transformers. -# MLlib types, algorithms and utilities +We list major functionality from both below, with links to detailed guides. -This lists functionality included in `spark.mllib`, the main MLlib API. +# spark.mllib: data types, algorithms, and utilities * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -56,71 +63,63 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) * [PMML model export](mllib-pmml-model-export.html) -MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and the migration guide below will explain all changes between releases. - # spark.ml: high-level APIs for ML pipelines -Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of -high-level APIs that help users create and tune practical machine learning pipelines. - -*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. - -Note that we will keep supporting and adding features to `spark.mllib` along with the -development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.mllib` and can optionally contribute -to `spark.ml`. 
- -Guides for `spark.ml` include: +**[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major +concepts. It also contains sections on using algorithms within the Pipelines API, for example: -* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts -* Guides on using algorithms within the Pipelines API: - * [Feature transformers](ml-features.html), including a few not in the lower-level `spark.mllib` API - * [Decision trees](ml-decision-tree.html) - * [Ensembles](ml-ensembles.html) - * [Linear methods](ml-linear-methods.html) +* [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) # Dependencies -MLlib uses the linear algebra package -[Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java) for optimised -numerical processing. If natives are not available at runtime, you -will see a warning message and a pure JVM implementation will be used -instead. +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. +If natives libraries[^1] are not available at runtime, you will see a warning message and a pure JVM +implementation will be used instead. -To learn more about the benefits and background of system optimised -natives, you may wish to watch Sam Halliday's ScalaX talk on -[High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/)). +Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native +proxies by default. +To configure `netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your +project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your +platform's additional installation instructions. -Due to licensing issues with runtime proprietary binaries, we do not -include `netlib-java`'s native proxies by default. To configure -`netlib-java` / Breeze to use system optimised binaries, include -`com.github.fommil.netlib:all:1.1.2` (or build Spark with -`-Pnetlib-lgpl`) as a dependency of your project and read the -[netlib-java](https://github.com/fommil/netlib-java) documentation for -your platform's additional installation instructions. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) -version 1.4 or newer. +[^1]: To learn more about the benefits and background of system optimised natives, you may wish to + watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). ---- +# Migration guide -# Migration Guide +MLlib is under active development. +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +and the migration guide below will explain all changes between releases. + +## From 1.4 to 1.5 -For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). 
+In the `spark.mllib` package, there are no break API changes but several behavior changes: -## From 1.3 to 1.4 +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. -In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: +In the `spark.ml` package, there exists one break API change and one behavior change: -* Gradient-Boosted Trees - * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. - * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. -* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. -## Previous Spark Versions +## Previous Spark versions Earlier migration guides are archived [on this page](mllib-migration-guides.html). + +--- diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 8df68d81f3c78..774b85d1f773a 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,25 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. 
It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. + +In the `spark.ml` package, several major API changes occurred, including: + +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes + +Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here. +However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API +changes for future releases. + ## From 1.2 to 1.3 In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. @@ -23,6 +42,17 @@ In the `spark.mllib` package, there were several breaking changes. The first ch * In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in From bb7f35239385ec74b5ee69631b5480fbcee253e4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 28 Aug 2015 14:38:20 -0700 Subject: [PATCH 134/802] [SPARK-10323] [SQL] fix nullability of In/InSet/ArrayContain After this PR, In/InSet/ArrayContain will return null if value is null, instead of false. They also will return null even if there is a null in the set/array. Author: Davies Liu Closes #8492 from davies/fix_in. 
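For readers skimming the diff below, here is a minimal sketch of the new semantics (not part of the patch; the object name and the standalone `main` are illustrative), evaluated through the Catalyst expression API that the updated test suites exercise:

```scala
// Sketch only: spells out the three-valued logic described above for In.
// InSet and ArrayContains follow the same rule: a definite match wins;
// otherwise a null value or a null element makes the result null, not false.
import org.apache.spark.sql.catalyst.expressions.{In, Literal}
import org.apache.spark.sql.types.IntegerType

object InNullSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val nullInt = Literal.create(null, IntegerType)

    // 1 IN (1, NULL) -> true: the match is decisive despite the null element
    println(In(Literal(1), Seq(Literal(1), nullInt)).eval())
    // 2 IN (1, NULL) -> null: no match, and the null element leaves the answer unknown
    println(In(Literal(2), Seq(Literal(1), nullInt)).eval())
    // NULL IN (1, 2) -> null: a null value no longer evaluates to false
    println(In(nullInt, Seq(Literal(1), Literal(2))).eval())
  }
}
```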
--- .../expressions/collectionOperations.scala | 62 ++++++++-------- .../sql/catalyst/expressions/predicates.scala | 71 +++++++++++++++---- .../sql/catalyst/optimizer/Optimizer.scala | 6 -- .../CollectionFunctionsSuite.scala | 12 +++- .../catalyst/expressions/PredicateSuite.scala | 49 +++++++++---- .../optimizer/ConstantFoldingSuite.scala | 21 +----- .../spark/sql/DataFrameFunctionsSuite.scala | 14 ++-- 7 files changed, 138 insertions(+), 97 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 646afa4047d84..7b8c5b723ded4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{ - CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} import org.apache.spark.sql.types._ /** @@ -145,46 +143,42 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def nullable: Boolean = false + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } - override def eval(input: InternalRow): Boolean = { - val arr = left.eval(input) - if (arr == null) { - false - } else { - val value = right.eval(input) - if (value == null) { - false - } else { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) return true - ) - false + override def nullSafeEval(arr: Any, value: Any): Any = { + var hasNull = false + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true } + ) + if (hasNull) { + null + } else { + false } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrGen = left.gen(ctx) - val elementGen = right.gen(ctx) - val i = ctx.freshName("i") - val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) - s""" - ${arrGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; - if (!${arrGen.isNull}) { - ${elementGen.code} - if (!${elementGen.isNull}) { - for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { - if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { - ${ev.primitive} = true; - break; - } - } + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + val getValue = ctx.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.primitive} = true; + break; } } """ + }) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fe7dffb815987..65706dba7d975 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -103,6 +101,8 @@ case class Not(child: Expression) case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { + require(list != null, "list should not be null") + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) override def checkInputDataTypes(): TypeCheckResult = { @@ -116,12 +116,31 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate override def children: Seq[Expression] = value +: list - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - list.exists(e => e.eval(input) == evaluatedValue) + if (evaluatedValue == null) { + null + } else { + var hasNull = false + list.foreach { e => + val v = e.eval(input) + if (v == evaluatedValue) { + return true + } else if (v == null) { + hasNull = true + } + } + if (hasNull) { + null + } else { + false + } + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -131,7 +150,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" if (!${ev.primitive}) { ${x.code} - if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + if (${x.isNull}) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + ${ev.isNull} = false; ${ev.primitive} = true; } } @@ -139,8 +161,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" ${valueGen.code} boolean ${ev.primitive} = false; - boolean ${ev.isNull} = false; - $listCode + boolean ${ev.isNull} = ${valueGen.isNull}; + if (!${ev.isNull}) { + $listCode + } """ } } @@ -151,11 +175,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate */ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. 
+ require(hset != null, "hset could not be null") + override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: InternalRow): Any = { - hset.contains(child.eval(input)) + @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + + override def nullable: Boolean = child.nullable || hasNull + + protected override def nullSafeEval(value: Any): Any = { + if (hset.contains(value)) { + true + } else if (hasNull) { + null + } else { + false + } } def getHSet(): Set[Any] = hset @@ -166,12 +201,20 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with val childGen = child.gen(ctx) ctx.references += this val hsetTerm = ctx.freshName("hset") + val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") s""" ${childGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + boolean ${ev.isNull} = ${childGen.isNull}; + boolean ${ev.primitive} = false; + if (!${ev.isNull}) { + ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + if (!${ev.primitive} && $hasNullTerm) { + ${ev.isNull} = true; + } + } """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 854463dd11c74..a430000bef653 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -395,12 +395,6 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) - - // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. 
- case In(Literal(v, _), list) if list.exists { - case Literal(candidate, _) if candidate == v => true - case _ => false - } => Literal.create(true, BooleanType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 95f0e38212a1a..a3e81888dfd0d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -70,14 +70,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) - checkEvaluation(ArrayContains(a0, Literal(null)), false) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) checkEvaluation(ArrayContains(a1, Literal("")), true) - checkEvaluation(ArrayContains(a1, Literal(null)), false) + checkEvaluation(ArrayContains(a1, Literal("a")), null) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) - checkEvaluation(ArrayContains(a2, Literal(null)), false) + checkEvaluation(ArrayContains(a2, Literal(1L)), null) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContains(a3, Literal("")), null) + checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 54c04faddb477..03e7611fce8ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.types._ @@ -119,6 +118,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))), + null) + checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) @@ -126,14 +131,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - checkEvaluation(In(Literal("^Ba*n"), 
Seq(Literal("^Ba*n"))), true) + val ns = Literal.create(null, StringType) + checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(ns, Seq(ns)), null) + checkEvaluation(In(Literal("a"), Seq(ns)), null) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -142,9 +151,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(Literal(_)) - checkEvaluation(In(input(0), input.slice(1, 10)), - inputData.slice(1, 10).contains(inputData(0))) + val input = inputData.map(Literal.create(_, t)) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(In(input(0), input.slice(1, 10)), expected) } } @@ -158,15 +175,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + checkEvaluation(InSet(three, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -176,8 +193,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(Literal(_)) - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), - inputData.slice(1, 10).contains(inputData(0))) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ec3b2f1edfa05..e67606288f514 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: Fold In(v, list) into true or false") { - var originalQuery = + val originalQuery = testRelation .select('a) 
.where(In(Literal(1), Seq(Literal(1), Literal(2)))) - var optimized = Optimize.execute(originalQuery.analyze) - - var correctAnswer = - testRelation - .select('a) - .where(Literal(true)) - .analyze - - comparePlans(optimized, correctAnswer) - - originalQuery = - testRelation - .select('a) - .where(In(Literal(1), Seq(Literal(1), 'a.attr))) - - optimized = Optimize.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - correctAnswer = + val correctAnswer = testRelation .select('a) .where(Literal(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9d965258e389d..3a3f19af1473b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -366,10 +366,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) - checkAnswer( - df.select(array_contains(array(lit(2), lit(null)), 1)), - Seq(Row(false), Row(false)) - ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -382,15 +378,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(null, 1)") } - // In hive, if either argument has a matching type has a null value, return false, even if - // the first argument array contains a null and the second argument is null checkAnswer( - df.selectExpr("array_contains(array(array(1), null)[1], 1)"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) ) checkAnswer( - df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) ) } } From 2a4e00ca4d4e7a148b4ff8ce0ad1c6d517cee55f Mon Sep 17 00:00:00 2001 From: felixcheung Date: Fri, 28 Aug 2015 18:35:01 -0700 Subject: [PATCH 135/802] [SPARK-9803] [SPARKR] Add subset and transform + tests Add subset and transform Also reorganize `[` & `[[` to subset instead of select Note: for transform, transform is very similar to mutate. Spark doesn't seem to replace existing column with the name in mutate (ie. `mutate(df, age = df$age + 2)` - returned DataFrame has 2 columns with the same name 'age'), so therefore not doing that for now in transform. Though it is clearly stated it should replace column with matching name (should I open a JIRA for mutate/transform?) Author: felixcheung Closes #8503 from felixcheung/rsubset_transform. 
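Since the new SparkR methods are thin wrappers over existing DataFrame operations, a rough Scala-API analogue may help readers who do not use R. This sketch is not part of the patch; the column names (`name`, `age`) and the object and method names are illustrative only.

```scala
// Rough Scala-API analogue of the new SparkR calls
//   subset(df, df$age %in% c(19, 30), c(1, 2))
//   transform(df, newAge = -df$age)
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

object SparkRSubsetTransformSketch {
  // subset(x, condition, columns): filter the rows, then project the requested columns
  def subsetLike(df: DataFrame): DataFrame =
    df.filter(col("age").isin(19, 30)).select("name", "age")

  // transform(x, newCol = expr): keep every existing column and append the new one,
  // which is why a column with a clashing name is currently duplicated rather than replaced
  def transformLike(df: DataFrame): DataFrame =
    df.select(col("*"), (-col("age")).as("newAge"))
}
```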
--- R/pkg/NAMESPACE | 2 + R/pkg/R/DataFrame.R | 70 +++++++++++++++++++++++++------- R/pkg/R/generics.R | 10 ++++- R/pkg/inst/tests/test_sparkSQL.R | 20 ++++++++- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5286c01986204..9d39630706436 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -69,9 +69,11 @@ exportMethods("arrange", "selectExpr", "show", "showDF", + "subset", "summarize", "summary", "take", + "transform", "unionAll", "unique", "unpersist", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 74de7c81e35a6..8a00238b41d60 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -987,7 +987,7 @@ setMethod("$<-", signature(x = "DataFrame"), setClassUnion("numericOrcharacter", c("numeric", "character")) -#' @rdname select +#' @rdname subset #' @name [[ setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), function(x, i) { @@ -998,7 +998,7 @@ setMethod("[[", signature(x = "DataFrame", i = "numericOrcharacter"), getColumn(x, i) }) -#' @rdname select +#' @rdname subset #' @name [ setMethod("[", signature(x = "DataFrame", i = "missing"), function(x, i, j, ...) { @@ -1012,7 +1012,7 @@ setMethod("[", signature(x = "DataFrame", i = "missing"), select(x, j) }) -#' @rdname select +#' @rdname subset #' @name [ setMethod("[", signature(x = "DataFrame", i = "Column"), function(x, i, j, ...) { @@ -1020,12 +1020,43 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), # https://stat.ethz.ch/R-manual/R-devel/library/base/html/Extract.data.frame.html filtered <- filter(x, i) if (!missing(j)) { - filtered[, j] + filtered[, j, ...] } else { filtered } }) +#' Subset +#' +#' Return subsets of DataFrame according to given conditions +#' @param x A DataFrame +#' @param subset A logical expression to filter on rows +#' @param select expression for the single Column or a list of columns to select from the DataFrame +#' @return A new DataFrame containing only the rows that meet the condition with selected columns +#' @export +#' @rdname subset +#' @name subset +#' @aliases [ +#' @family subsetting functions +#' @examples +#' \dontrun{ +#' # Columns can be selected using `[[` and `[` +#' df[[2]] == df[["age"]] +#' df[,2] == df[,"age"] +#' df[,c("name", "age")] +#' # Or to filter rows +#' df[df$age > 20,] +#' # DataFrame can be subset on both rows and Columns +#' df[df$name == "Smith", c(1,2)] +#' df[df$age %in% c(19, 30), 1:2] +#' subset(df, df$age %in% c(19, 30), 1:2) +#' subset(df, df$age %in% c(19), select = c(1,2)) +#' } +setMethod("subset", signature(x = "DataFrame"), + function(x, subset, select, ...) { + x[subset, select, ...] + }) + #' Select #' #' Selects a set of columns with names or Column expressions. 
@@ -1034,6 +1065,8 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' @return A new DataFrame with selected columns #' @export #' @rdname select +#' @name select +#' @family subsetting functions #' @examples #' \dontrun{ #' select(df, "*") @@ -1041,15 +1074,8 @@ setMethod("[", signature(x = "DataFrame", i = "Column"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Columns can also be selected using `[[` and `[` -#' df[[2]] == df[["age"]] -#' df[,2] == df[,"age"] -#' df[,c("name", "age")] #' # Similar to R data frames columns can also be selected using `$` #' df$age -#' # It can also be subset on rows and Columns -#' df[df$name == "Smith", c(1,2)] -#' df[df$age %in% c(19, 30), 1:2] #' } setMethod("select", signature(x = "DataFrame", col = "character"), function(x, col, ...) { @@ -1121,7 +1147,7 @@ setMethod("selectExpr", #' @return A DataFrame with the new column added. #' @rdname withColumn #' @name withColumn -#' @aliases mutate +#' @aliases mutate transform #' @export #' @examples #'\dontrun{ @@ -1141,11 +1167,12 @@ setMethod("withColumn", #' #' Return a new DataFrame with the specified columns added. #' -#' @param x A DataFrame +#' @param .data A DataFrame #' @param col a named argument of the form name = col #' @return A new DataFrame with the new columns added. #' @rdname withColumn #' @name mutate +#' @aliases withColumn transform #' @export #' @examples #'\dontrun{ @@ -1155,10 +1182,12 @@ setMethod("withColumn", #' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 +#' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) #' } setMethod("mutate", - signature(x = "DataFrame"), - function(x, ...) { + signature(.data = "DataFrame"), + function(.data, ...) { + x <- .data cols <- list(...) stopifnot(length(cols) > 0) stopifnot(class(cols[[1]]) == "Column") @@ -1173,6 +1202,16 @@ setMethod("mutate", do.call(select, c(x, x$"*", cols)) }) +#' @export +#' @rdname withColumn +#' @name transform +#' @aliases withColumn mutate +setMethod("transform", + signature(`_data` = "DataFrame"), + function(`_data`, ...) { + mutate(`_data`, ...) + }) + #' WithColumnRenamed #' #' Rename an existing column in a DataFrame. @@ -1300,6 +1339,7 @@ setMethod("orderBy", #' @return A DataFrame containing only the rows that meet the condition. #' @rdname filter #' @name filter +#' @family subsetting functions #' @export #' @examples #'\dontrun{ diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index b578b8789d2c5..43dd8d283ab6b 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -467,7 +467,7 @@ setGeneric("merge") #' @rdname withColumn #' @export -setGeneric("mutate", function(x, ...) {standardGeneric("mutate") }) +setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) #' @rdname arrange #' @export @@ -507,6 +507,10 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) +#' @rdname withColumn +#' @export +setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) @@ -531,6 +535,10 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x,...) 
{ standardGeneric("showDF") }) +# @rdname subset +# @export +setGeneric("subset", function(x, subset, select, ...) { standardGeneric("subset") }) + #' @rdname agg #' @export setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 933b11c8ee7e2..0da5e38654732 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -612,6 +612,10 @@ test_that("subsetting", { df5 <- df[df$age %in% c(19), c(1,2)] expect_equal(count(df5), 1) expect_equal(columns(df5), c("name", "age")) + + df6 <- subset(df, df$age %in% c(30), c(1,2)) + expect_equal(count(df6), 1) + expect_equal(columns(df6), c("name", "age")) }) test_that("selectExpr() on a DataFrame", { @@ -1028,7 +1032,7 @@ test_that("withColumn() and withColumnRenamed()", { expect_equal(columns(newDF2)[1], "newerAge") }) -test_that("mutate(), rename() and names()", { +test_that("mutate(), transform(), rename() and names()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) @@ -1042,6 +1046,20 @@ test_that("mutate(), rename() and names()", { names(newDF2) <- c("newerName", "evenNewerAge") expect_equal(length(names(newDF2)), 2) expect_equal(names(newDF2)[1], "newerName") + + transformedDF <- transform(df, newAge = -df$age, newAge2 = df$age / 2) + expect_equal(length(columns(transformedDF)), 4) + expect_equal(columns(transformedDF)[3], "newAge") + expect_equal(columns(transformedDF)[4], "newAge2") + expect_equal(first(filter(transformedDF, transformedDF$name == "Andy"))$newAge, -30) + + # test if transform on local data frames works + # ensure the proper signature is used - otherwise this will fail to run + attach(airquality) + result <- transform(Ozone, logOzone = log(Ozone)) + expect_equal(nrow(result), 153) + expect_equal(ncol(result), 2) + detach(airquality) }) test_that("write.df() on DataFrame and works with parquetFile", { From e8ea5bafee9ca734edf62021145d0c2d5491cba8 Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Fri, 28 Aug 2015 21:03:48 -0700 Subject: [PATCH 136/802] [SPARK-9910] [ML] User guide for train validation split Author: martinzapletal Closes #8377 from zapletal-martin/SPARK-9910. --- docs/ml-guide.md | 117 ++++++++++++++++++ .../ml/JavaTrainValidationSplitExample.java | 90 ++++++++++++++ .../ml/TrainValidationSplitExample.scala | 80 ++++++++++++ 3 files changed, 287 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala diff --git a/docs/ml-guide.md b/docs/ml-guide.md index ce53400b6ee56..a92a285f3af85 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -872,3 +872,120 @@ jsc.stop(); + +## Example: Model Selection via Train Validation Split +In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. +`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in + case of `CrossValidator`. It is therefore less expensive, + but will not produce as reliable results when the training dataset is not sufficiently large.. + +`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, +and an `Evaluator`. 
+It begins by splitting the dataset into two parts using the `trainRatio` parameter,
+which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default),
+`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.
+Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
+For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
+The `ParamMap` which produces the best evaluation metric is selected as the best option.
+`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %} +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils + +// Prepare training and test data. +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + +val lr = new LinearRegression() + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + +// 80% of the data will be used for training and the remaining 20% for validation. +trainValidationSplit.setTrainRatio(0.8) + +// Run train validation split, and choose the best set of parameters. +val model = trainValidationSplit.fit(training) + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show() + +{% endhighlight %} +
+
+<div data-lang="java" markdown="1">
+{% highlight java %} +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; + +DataFrame data = jsql.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + +// Prepare training and test data. +DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); +DataFrame training = splits[0]; +DataFrame test = splits[1]; + +LinearRegression lr = new LinearRegression(); + +// We use a ParamGridBuilder to construct a grid of parameters to search over. +// TrainValidationSplit will try all combinations of values and determine best model using +// the evaluator. +ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + +// In this case the estimator is simply the linear regression. +// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + +// 80% of the data will be used for training and the remaining 20% for validation. +trainValidationSplit.setTrainRatio(0.8); + +// Run train validation split, and choose the best set of parameters. +TrainValidationSplitModel model = trainValidationSplit.fit(training); + +// Make predictions on test data. model is the model with combination of parameters +// that performed best. +model.transform(test) + .select("features", "label", "prediction") + .show(); + +{% endhighlight %} +
+
+</div>
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java new file mode 100644 index 0000000000000..23f834ab4332b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.regression.LinearRegression; +import org.apache.spark.ml.tuning.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaTrainValidationSplitExample + * }}} + */ +public class JavaTrainValidationSplitExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + DataFrame data = jsql.createDataFrame( + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), + LabeledPoint.class); + + // Prepare training and test data. + DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); + DataFrame training = splits[0]; + DataFrame test = splits[1]; + + LinearRegression lr = new LinearRegression(); + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + ParamMap[] paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.1, 0.01}) + .addGrid(lr.fitIntercept()) + .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) + .build(); + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + TrainValidationSplit trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator()) + .setEstimatorParamMaps(paramGrid); + + // 80% of the data will be used for training and the remaining 20% for validation. 
+ trainValidationSplit.setTrainRatio(0.8); + + // Run train validation split, and choose the best set of parameters. + TrainValidationSplitModel model = trainValidationSplit.fit(training); + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show(); + + jsc.stop(); + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala new file mode 100644 index 0000000000000..1abdf219b1c00 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} + +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * The example is based on [[SimpleParamsExample]] using linear regression. + * Run with + * {{{ + * bin/run-example ml.TrainValidationSplitExample + * }}} + */ +object TrainValidationSplitExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("TrainValidationSplitExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Prepare training and test data. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) + + val lr = new LinearRegression() + + // We use a ParamGridBuilder to construct a grid of parameters to search over. + // TrainValidationSplit will try all combinations of values and determine best model using + // the evaluator. + val paramGrid = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.1, 0.01)) + .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) + .build() + + // In this case the estimator is simply the linear regression. + // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + val trainValidationSplit = new TrainValidationSplit() + .setEstimator(lr) + .setEvaluator(new RegressionEvaluator) + .setEstimatorParamMaps(paramGrid) + + // 80% of the data will be used for training and the remaining 20% for validation. 
+ trainValidationSplit.setTrainRatio(0.8) + + // Run train validation split, and choose the best set of parameters. + val model = trainValidationSplit.fit(training) + + // Make predictions on test data. model is the model with combination of parameters + // that performed best. + model.transform(test) + .select("features", "label", "prediction") + .show() + + sc.stop() + } +} From 5369be806848f43cb87c76504258c4e7de930c90 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 29 Aug 2015 13:20:22 -0700 Subject: [PATCH 137/802] [SPARK-10350] [DOC] [SQL] Removed duplicated option description from SQL guide Author: GuoQiang Li Closes #8520 from witgo/SPARK-10350. --- docs/sql-programming-guide.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index e8eb88488ee24..6a1b0fbfa1eb3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1405,16 +1405,6 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
 <tr>
   <td><code>spark.akka.failure-detector.threshold</code></td>
   <td>300.0</td>
   <td>
     This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
     enabled again, if you plan to use this feature (Not recommended). This maps to akka's
     `akka.remote.transport-failure-detector.threshold`. Tune this in combination of
     `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
   </td>
 </tr>
 <tr>
   <td><code>spark.akka.frameSize</code></td>
   <td>128</td>
 </tr>
-<tr>
-  <td><code>spark.sql.parquet.mergeSchema</code></td>
-  <td><code>false</code></td>
-  <td>
-    When true, the Parquet data source merges schemas collected from all data files, otherwise the
-    schema is picked from the summary file or a random data file if no summary file is available.
-  </td>
-</tr>
 </table>
## JSON Datasets From 24ffa85c002a095ffb270175ec838995d3ed5469 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 29 Aug 2015 13:24:32 -0700 Subject: [PATCH 138/802] [SPARK-10289] [SQL] A direct write API for testing Parquet This PR introduces a direct write API for testing Parquet. It's a DSL flavored version of the [`writeDirect` method] [1] comes with parquet-avro testing code. With this API, it's much easier to construct arbitrary Parquet structures. It's especially useful when adding regression tests for various compatibility corner cases. Sample usage of this API can be found in the new test case added in `ParquetThriftCompatibilitySuite`. [1]: https://github.com/apache/parquet-mr/blob/apache-parquet-1.8.1/parquet-avro/src/test/java/org/apache/parquet/avro/TestArrayCompatibility.java#L945-L972 Author: Cheng Lian Closes #8454 from liancheng/spark-10289/parquet-testing-direct-write-api. --- .../parquet/ParquetCompatibilityTest.scala | 84 +++++++++++++-- .../ParquetThriftCompatibilitySuite.scala | 100 +++++++++++++++--- 2 files changed, 160 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index df68432faeeb3..91f3ce4d34c8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} -import org.apache.parquet.hadoop.ParquetFileReader -import org.apache.parquet.schema.MessageType +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.hadoop.api.WriteSupport.WriteContext +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} +import org.apache.parquet.io.api.RecordConsumer +import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.QueryTest @@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq val fs = fsPath.getFileSystem(configuration) val parquetFiles = fs.listStatus(fsPath, new PathFilter { override def accept(path: Path): Boolean = pathFilter(path) - }).toSeq + }).toSeq.asJava - val footers = - ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true) - footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema + val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true) + footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema } protected def logParquetSchema(path: String): Unit = { @@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq } } -object ParquetCompatibilityTest { - def makeNullable[T <: AnyRef](i: Int)(f: => T): T = { - if (i % 3 == 0) null.asInstanceOf[T] else f +private[sql] object ParquetCompatibilityTest { + implicit class RecordConsumerDSL(consumer: RecordConsumer) { + def message(f: => Unit): Unit = { + consumer.startMessage() + f + consumer.endMessage() + } + + def group(f: => Unit): Unit = { + consumer.startGroup() + 
f + consumer.endGroup() + } + + def field(name: String, index: Int)(f: => Unit): Unit = { + consumer.startField(name, index) + f + consumer.endField(name, index) + } + } + + /** + * A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet + * records with arbitrary structures. + */ + private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String]) + extends WriteSupport[RecordConsumer => Unit] { + + private var recordConsumer: RecordConsumer = _ + + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, metadata.asJava) + } + + override def write(recordWriter: RecordConsumer => Unit): Unit = { + recordWriter.apply(recordConsumer) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + this.recordConsumer = recordConsumer + } + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`. + * Records are produced by `recordWriters`. + */ + def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = { + writeDirect(path, schema, Map.empty[String, String], recordWriters: _*) + } + + /** + * Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path` + * with given user-defined key-value `metadata`. Records are produced by `recordWriters`. + */ + def writeDirect( + path: String, + schema: String, + metadata: Map[String, String], + recordWriters: (RecordConsumer => Unit)*): Unit = { + val messageType = MessageTypeParser.parseMessageType(schema) + val writeSupport = new DirectWriteSupport(messageType, metadata) + val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport) + try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index b789c5a106e56..88a3d878f97fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar """.stripMargin) checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => - def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) - val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") - Row( + val nonNullablePrimitiveValues = Seq( i % 2 == 0, i.toByte, (i + 1).toShort, @@ -50,18 +48,15 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar s"val_$i", s"val_$i", // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings - suits(i % 4), - - nullable(i % 2 == 0: java.lang.Boolean), - nullable(i.toByte: java.lang.Byte), - nullable((i + 1).toShort: java.lang.Short), - nullable(i + 2: Integer), - nullable((i * 10).toLong: java.lang.Long), - nullable(i.toDouble + 0.2d: java.lang.Double), - nullable(s"val_$i"), - nullable(s"val_$i"), - nullable(suits(i % 4)), + suits(i % 4)) + + val nullablePrimitiveValues = if (i % 3 == 0) { + Seq.fill(nonNullablePrimitiveValues.length)(null) + } else { + nonNullablePrimitiveValues + } + val complexValues = Seq( Seq.tabulate(3)(n => s"arr_${i + n}"), // 
Thrift `SET`s are converted to Parquet `LIST`s Seq(i), @@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") } }.toMap) + + Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*) }) } + + test("SPARK-10136 list of primitive list") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // This Parquet schema is translated from the following Thrift schema: + // + // struct ListOfPrimitiveList { + // 1: list> f; + // } + val schema = + s"""message ListOfPrimitiveList { + | required group f (LIST) { + | repeated group f_tuple (LIST) { + | repeated int32 f_tuple_tuple; + | } + | } + |} + """.stripMargin + + writeDirect(path, schema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + } + } + } + } + }, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("f_tuple", 0) { + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(4) + rc.addInteger(5) + } + } + + rc.group { + rc.field("f_tuple_tuple", 0) { + rc.addInteger(6) + rc.addInteger(7) + } + } + } + } + } + } + }) + + logParquetSchema(path) + + checkAnswer( + sqlContext.read.parquet(path), + Seq( + Row(Seq(Seq(0, 1), Seq(2, 3))), + Row(Seq(Seq(4, 5), Seq(6, 7))))) + } + } } From 5c3d16a9b91bb9a458d3ba141f7bef525cf3d285 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 29 Aug 2015 13:26:01 -0700 Subject: [PATCH 139/802] [SPARK-10344] [SQL] Add tests for extraStrategies Actually using this API requires access to a lot of classes that we might make private by accident. I've added some tests to prevent this. Author: Michael Armbrust Closes #8516 from marmbrus/extraStrategiesTests. --- .../spark/sql/ExtraStrategiesSuite.scala | 67 +++++++++++++++++++ .../spark/sql/test/SharedSQLContext.scala | 2 +- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala new file mode 100644 index 0000000000000..8d2f45d70308b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.{Row, Strategy, QueryTest} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.unsafe.types.UTF8String + +case class FastOperator(output: Seq[Attribute]) extends SparkPlan { + + override protected def doExecute(): RDD[InternalRow] = { + val str = Literal("so fast").value + val row = new GenericInternalRow(Array[Any](str)) + sparkContext.parallelize(Seq(row)) + } + + override def children: Seq[SparkPlan] = Nil +} + +object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case Project(Seq(attr), _) if attr.name == "a" => + FastOperator(attr.toAttribute :: Nil) :: Nil + case _ => Nil + } +} + +class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("insert an extraStrategy") { + try { + sqlContext.experimental.extraStrategies = TestStrategy :: Nil + + val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") + checkAnswer( + df.select("a"), + Row("so fast")) + + checkAnswer( + df.select("a", "b"), + Row("so slow", 1)) + } finally { + sqlContext.experimental.extraStrategies = Nil + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 8a061b6bc690d..d23c6a0732669 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.{ColumnName, SQLContext} /** * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. */ -private[sql] trait SharedSQLContext extends SQLTestUtils { +trait SharedSQLContext extends SQLTestUtils { /** * The [[TestSQLContext]] to use for all tests in this suite. From 277148b285748e863f2b9fdf6cf12963977f91ca Mon Sep 17 00:00:00 2001 From: wangwei Date: Sat, 29 Aug 2015 13:29:50 -0700 Subject: [PATCH 140/802] [SPARK-10226] [SQL] Fix exclamation mark issue in SparkSQL When I tested the latest version of spark with exclamation mark, I got some errors. Then I reseted the spark version and found that commit id "a2409d1c8e8ddec04b529ac6f6a12b5993f0eeda" brought the bug. With jline version changing from 0.9.94 to 2.12 after this commit, exclamation mark would be treated as a special character in ConsoleReader. Author: wangwei Closes #8420 from small-wang/jline-SPARK-10226. 
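As a side note, the behavior being fixed can be reproduced with a minimal standalone sketch. This assumes jline 2.12 on the classpath; the object name, prompt, and echo loop below are illustrative only, and the actual change to SparkSQLCLIDriver is the one-line diff that follows.

```scala
import jline.console.ConsoleReader

// Toy console loop: with expandEvents left at its default (true), jline 2.x treats
// '!' as a history-expansion character, so input such as
//   SELECT * FROM src WHERE key != 1
// is rewritten or rejected before it ever reaches the SQL parser.
// Calling setExpandEvents(false) passes the line through verbatim.
object ExpandEventsSketch {
  def main(args: Array[String]): Unit = {
    val reader = new ConsoleReader()
    reader.setBellEnabled(false)
    reader.setExpandEvents(false) // the same call the fix below adds to SparkSQLCLIDriver
    var line = reader.readLine("sketch> ")
    while (line != null && line.trim() != "quit") {
      println(s"would execute: $line")
      line = reader.readLine("sketch> ")
    }
  }
}
```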
--- .../apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index a29df567983b1..b5073961a1c84 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -171,6 +171,7 @@ private[hive] object SparkSQLCLIDriver extends Logging { val reader = new ConsoleReader() reader.setBellEnabled(false) + reader.setExpandEvents(false) // reader.setDebug(new PrintWriter(new FileWriter("writer.debug", true))) CliDriver.getCommandCompleter.foreach((e) => reader.addCompleter(e)) From 6a6f3c91ee1f63dd464eb03d156d02c1a5887d88 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 29 Aug 2015 13:36:25 -0700 Subject: [PATCH 141/802] [SPARK-10330] Use SparkHadoopUtil TaskAttemptContext reflection methods in more places SparkHadoopUtil contains methods that use reflection to work around TaskAttemptContext binary incompatibilities between Hadoop 1.x and 2.x. We should use these methods in more places. Author: Josh Rosen Closes #8499 from JoshRosen/use-hadoop-reflection-in-more-places. --- .../sql/execution/datasources/WriterContainer.scala | 10 +++++++--- .../sql/execution/datasources/json/JSONRelation.scala | 7 +++++-- .../datasources/parquet/ParquetRelation.scala | 7 +++++-- .../org/apache/spark/sql/hive/orc/OrcRelation.scala | 9 ++++++--- .../apache/spark/sql/sources/SimpleTextRelation.scala | 7 +++++-- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 78f48a5cd72c7..879fd69863211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ @@ -145,7 +146,8 @@ private[sql] abstract class BaseWriterContainer( "because spark.speculation is configured to be true.") defaultOutputCommitter } else { - val committerClass = context.getConfiguration.getClass( + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) Option(committerClass).map { clazz => @@ -227,7 +229,8 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set("spark.sql.sources.output.path", outputPath) val writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, 
taskAttemptContext) writer.initConverter(dataSchema) @@ -395,7 +398,8 @@ private[sql] class DynamicPartitionWriterContainer( def newOutputWriter(key: InternalRow): OutputWriter = { val partitionPath = getPartitionString(key).getString(0) val path = new Path(getWorkPath, partitionPath) - taskAttemptContext.getConfiguration.set( + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + configuration.set( "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) val newWriter = outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) newWriter.initConverter(dataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index ab8ca5f748f24..7a49157d9e72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -169,8 +170,10 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } }.getRecordWriter(context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 64982f37cf872..c6bbc392cad4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -40,6 +40,7 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow @@ -81,8 +82,10 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. 
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1cff5cf9c3543..4eeca9aec12bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow @@ -77,7 +78,8 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - serde.initialize(context.getConfiguration, table) + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + serde.initialize(configuration, table) serde } @@ -109,9 +111,10 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = context.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val partition = context.getTaskAttemptID.getTaskID.getId + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val partition = taskAttemptId.getTaskID.getId val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index e8141923a9b5c..527ca7a81cad8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputForma import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types.{DataType, StructType} @@ -53,8 +54,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId + val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val uniqueWriteJobId = 
configuration.get("spark.sql.sources.writeJobUUID") + val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } From 097a7e36e0bf7290b1879331375bacc905583bd3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 29 Aug 2015 16:39:40 -0700 Subject: [PATCH 142/802] [SPARK-10339] [SPARK-10334] [SPARK-10301] [SQL] Partitioned table scan can OOM driver and throw a better error message when users need to enable parquet schema merging This fixes the problem that scanning partitioned table causes driver have a high memory pressure and takes down the cluster. Also, with this fix, we will be able to correctly show the query plan of a query consuming partitioned tables. https://issues.apache.org/jira/browse/SPARK-10339 https://issues.apache.org/jira/browse/SPARK-10334 Finally, this PR squeeze in a "quick fix" for SPARK-10301. It is not a real fix, but it just throw a better error message to let user know what to do. Author: Yin Huai Closes #8515 from yhuai/partitionedTableScan. --- .../datasources/DataSourceStrategy.scala | 85 ++++++++++--------- .../parquet/CatalystRowConverter.scala | 7 ++ .../ParquetHadoopFsRelationSuite.scala | 15 +++- 3 files changed, 65 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6c1ef6a6df887..c58213155daa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} @@ -121,7 +122,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { projections: Seq[NamedExpression], filters: Seq[Expression], partitionColumns: StructType, - partitions: Array[Partition]) = { + partitions: Array[Partition]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] // Because we are creating one RDD per partition, we need to have a shared HadoopConf. @@ -130,49 +131,51 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val confBroadcast = relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // The table scan operator (PhysicalRDD) which retrieves required columns from data files. - // Notice that the schema of data files, represented by `relation.dataSchema`, may contain - // some partition column(s). - val scan = - pruneFilterProject( - logicalRelation, - projections, - filters, - (columns: Seq[Attribute], filters) => { - val partitionColNames = partitionColumns.fieldNames - - // Don't scan any partition columns to save I/O. 
Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with those - // partition values encoded in partition directory paths. - val needed = columns.filterNot(a => partitionColNames.contains(a.name)) - val dataRows = - relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) - - // Merges data values with partition values. - mergeWithPartitionValues( - relation.schema, - columns.map(_.name).toArray, - partitionColNames, - partitionValues, - toCatalystRDD(logicalRelation, needed, dataRows)) - }) - - scan.execute() - } + // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder + // will union all partitions and attach partition values if needed. + val scanBuilder = { + (columns: Seq[Attribute], filters: Array[Filter]) => { + // Builds RDD[Row]s for each selected partition. + val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => + val partitionColNames = partitionColumns.fieldNames + + // Don't scan any partition columns to save I/O. Here we are being optimistic and + // assuming partition columns data stored in data files are always consistent with those + // partition values encoded in partition directory paths. + val needed = columns.filterNot(a => partitionColNames.contains(a.name)) + val dataRows = + relation.buildScan(needed.map(_.name).toArray, filters, Array(dir), confBroadcast) + + // Merges data values with partition values. + mergeWithPartitionValues( + relation.schema, + columns.map(_.name).toArray, + partitionColNames, + partitionValues, + toCatalystRDD(logicalRelation, needed, dataRows)) + } + + val unionedRows = + if (perPartitionRows.length == 0) { + relation.sqlContext.emptyResult + } else { + new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + } - val unionedRows = - if (perPartitionRows.length == 0) { - relation.sqlContext.emptyResult - } else { - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + unionedRows } + } + + // Create the scan operator. If needed, add Filter and/or Project on top of the scan. + // The added Filter/Project is on top of the unioned RDD. We do not want to create + // one Filter/Project for every partition. + val sparkPlan = pruneFilterProject( + logicalRelation, + projections, + filters, + scanBuilder) - execution.PhysicalRDD.createFromDataSource( - projections.map(_.toAttribute), - unionedRows, - logicalRelation.relation) + sparkPlan } // TODO: refactor this thing. It is very complicated because it does projection internally. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index f682ca0d8ff4f..fe13dfbbed385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -196,6 +196,13 @@ private[parquet] class CatalystRowConverter( } } + if (paddedParquetFields.length != catalystType.length) { + throw new UnsupportedOperationException( + "A Parquet file's schema has different number of fields with the table schema. 
" + + "Please enable schema merging by setting \"mergeSchema\" to true when load " + + "a Parquet dataset or set spark.sql.parquet.mergeSchema to true in SQLConf.") + } + paddedParquetFields.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index cb4cedddbfddd..06dadbb5feab0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.{execution, AnalysisException, SaveMode} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -136,4 +136,17 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { assert(fs.exists(commonSummaryPath)) } } + + test("SPARK-10334 Projections and filters should be kept in physical plan") { + withTempPath { dir => + val path = dir.getCanonicalPath + + sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) + val df = sqlContext.read.parquet(path).filter('a === 0).select('b) + val physicalPlan = df.queryExecution.executedPlan + + assert(physicalPlan.collect { case p: execution.Project => p }.length === 1) + assert(physicalPlan.collect { case p: execution.Filter => p }.length === 1) + } + } } From 13f5f8ec97c6886346641b73bd99004e0d70836c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 29 Aug 2015 18:10:44 -0700 Subject: [PATCH 143/802] [SPARK-9986] [SPARK-9991] [SPARK-9993] [SQL] Create a simple test framework for local operators This PR includes the following changes: - Add `LocalNodeTest` for local operator tests and add unit tests for FilterNode and ProjectNode. - Add `LimitNode` and `UnionNode` and their unit tests to show how to use `LocalNodeTest`. (SPARK-9991, SPARK-9993) Author: zsxwing Closes #8464 from zsxwing/local-execution. 
--- .../sql/execution/local/FilterNode.scala | 6 +- .../spark/sql/execution/local/LimitNode.scala | 45 ++++++ .../spark/sql/execution/local/LocalNode.scala | 13 +- .../sql/execution/local/ProjectNode.scala | 4 +- .../sql/execution/local/SeqScanNode.scala | 2 +- .../spark/sql/execution/local/UnionNode.scala | 72 +++++++++ .../spark/sql/execution/SparkPlanTest.scala | 46 +----- .../sql/execution/local/FilterNodeSuite.scala | 41 +++++ .../sql/execution/local/LimitNodeSuite.scala | 39 +++++ .../sql/execution/local/LocalNodeTest.scala | 146 ++++++++++++++++++ .../execution/local/ProjectNodeSuite.scala | 44 ++++++ .../sql/execution/local/UnionNodeSuite.scala | 52 +++++++ .../apache/spark/sql/test/SQLTestData.scala | 8 + .../apache/spark/sql/test/SQLTestUtils.scala | 46 +++++- 14 files changed, 509 insertions(+), 55 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala index a485a1a1d7ae4..81dd37c7da733 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala @@ -35,13 +35,13 @@ case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLoca override def next(): Boolean = { var found = false - while (child.next() && !found) { - found = predicate.apply(child.get()) + while (!found && child.next()) { + found = predicate.apply(child.fetch()) } found } - override def get(): InternalRow = child.get() + override def fetch(): InternalRow = child.fetch() override def close(): Unit = child.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala new file mode 100644 index 0000000000000..fffc52abf6dd5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala @@ -0,0 +1,45 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + + +case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode { + + private[this] var count = 0 + + override def output: Seq[Attribute] = child.output + + override def open(): Unit = child.open() + + override def close(): Unit = child.close() + + override def fetch(): InternalRow = child.fetch() + + override def next(): Boolean = { + if (count < limit) { + count += 1 + child.next() + } else { + false + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 341c81438e6d6..1c4469acbf264 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -48,10 +48,10 @@ abstract class LocalNode extends TreeNode[LocalNode] { /** * Returns the current tuple. */ - def get(): InternalRow + def fetch(): InternalRow /** - * Closes the iterator and releases all resources. + * Closes the iterator and releases all resources. It should be idempotent. * * Implementations of this must also call the `close()` function of its children. */ @@ -64,10 +64,13 @@ abstract class LocalNode extends TreeNode[LocalNode] { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() - while (next()) { - result += converter.apply(get()).asInstanceOf[Row] + try { + while (next()) { + result += converter.apply(fetch()).asInstanceOf[Row] + } + } finally { + close() } - close() result } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index e574d1473cdcb..9b8a4fe493026 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -34,8 +34,8 @@ case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) exte override def next(): Boolean = child.next() - override def get(): InternalRow = { - project.apply(child.get()) + override def fetch(): InternalRow = { + project.apply(child.fetch()) } override def close(): Unit = child.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala index 994de8afa9a02..242cb66e07b7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala @@ -41,7 +41,7 @@ case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends L } } - override def get(): InternalRow = currentRow + override def fetch(): InternalRow = currentRow override def close(): Unit = { // Do nothing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala new file mode 100644 index 0000000000000..ba4aa7671aebd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more 
+* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class UnionNode(children: Seq[LocalNode]) extends LocalNode { + + override def output: Seq[Attribute] = children.head.output + + private[this] var currentChild: LocalNode = _ + + private[this] var nextChildIndex: Int = _ + + override def open(): Unit = { + currentChild = children.head + currentChild.open() + nextChildIndex = 1 + } + + private def advanceToNextChild(): Boolean = { + var found = false + var exit = false + while (!exit && !found) { + if (currentChild != null) { + currentChild.close() + } + if (nextChildIndex >= children.size) { + found = false + exit = true + } else { + currentChild = children(nextChildIndex) + nextChildIndex += 1 + currentChild.open() + found = currentChild.next() + } + } + found + } + + override def close(): Unit = { + if (currentChild != null) { + currentChild.close() + } + } + + override def fetch(): InternalRow = currentChild.fetch() + + override def next(): Boolean = { + if (currentChild.next()) { + true + } else { + advanceToNextChild() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 3a87f374d94b0..5ab8f44faebf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test.SQLTestUtils /** * Base class for writing tests for individual physical operators. For an example of how this @@ -184,7 +184,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match. 
| Actual result Spark plan: @@ -229,7 +229,7 @@ object SparkPlanTest { return Some(errorMessage) } - compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => + SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage => s""" | Results do not match for Spark plan: | $outputPlan @@ -238,46 +238,6 @@ object SparkPlanTest { } } - private def compareAnswers( - sparkAnswer: Seq[Row], - expectedAnswer: Seq[Row], - sort: Boolean): Option[String] = { - def prepareAnswer(answer: Seq[Row]): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - // This function is copied from Catalyst's QueryTest - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } - if (sort) { - converted.sortBy(_.toString()) - } else { - converted - } - } - if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { - val errorMessage = - s""" - | == Results == - | ${sideBySide( - s"== Expected Answer - ${expectedAnswer.size} ==" +: - prepareAnswer(expectedAnswer).map(_.toString()), - s"== Actual Answer - ${sparkAnswer.size} ==" +: - prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} - """.stripMargin - Some(errorMessage) - } else { - None - } - } - private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala new file mode 100644 index 0000000000000..07209f3779248 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -0,0 +1,41 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + val condition = (testData.col("key") % 2) === 0 + checkAnswer( + testData, + node => FilterNode(condition.expr, node), + testData.filter(condition).collect() + ) + } + + test("empty") { + val condition = (emptyTestData.col("key") % 2) === 0 + checkAnswer( + emptyTestData, + node => FilterNode(condition.expr, node), + emptyTestData.filter(condition).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala new file mode 100644 index 0000000000000..523c02f4a6014 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -0,0 +1,39 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + checkAnswer( + testData, + node => LimitNode(10, node), + testData.limit(10).collect() + ) + } + + test("empty") { + checkAnswer( + emptyTestData, + node => LimitNode(10, node), + emptyTestData.limit(10).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala new file mode 100644 index 0000000000000..95f06081bd0a8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -0,0 +1,146 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.test.SQLTestUtils + +class LocalNodeTest extends SparkFunSuite { + + /** + * Runs the LocalNode and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate + * the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer( + input: DataFrame, + nodeFunction: LocalNode => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + input :: Nil, + nodes => nodeFunction(nodes.head), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the LocalNode and makes sure the answer matches the expected result. + * @param left the left input data to be used. + * @param right the right input data to be used. + * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate + * the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def checkAnswer2( + left: DataFrame, + right: DataFrame, + nodeFunction: (LocalNode, LocalNode) => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + nodes => nodeFunction(nodes(0), nodes(1)), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def doCheckAnswer( + input: Seq[DataFrame], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + LocalNodeTest.checkAnswer( + input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { + new SeqScanNode( + df.queryExecution.sparkPlan.output, + df.queryExecution.toRdd.map(_.copy()).collect()) + } + +} + +/** + * Helper methods for writing tests of individual local physical operators. + */ +object LocalNodeTest { + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. 
+ */ + def checkAnswer( + input: Seq[SeqScanNode], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { + + val outputNode = nodeFunction(input) + + val outputResult: Seq[Row] = try { + outputNode.collect() + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing local plan: + | $outputNode + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for local plan: + | $outputNode + | $errorMessage + """.stripMargin + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala new file mode 100644 index 0000000000000..ffcf092e2c66a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -0,0 +1,44 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + val output = testData.queryExecution.sparkPlan.output + val columns = Seq(output(1), output(0)) + checkAnswer( + testData, + node => ProjectNode(columns, node), + testData.select("value", "key").collect() + ) + } + + test("empty") { + val output = emptyTestData.queryExecution.sparkPlan.output + val columns = Seq(output(1), output(0)) + checkAnswer( + emptyTestData, + node => ProjectNode(columns, node), + emptyTestData.select("value", "key").collect() + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala new file mode 100644 index 0000000000000..34670287c3e1d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -0,0 +1,52 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.test.SharedSQLContext + +class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { + + test("basic") { + checkAnswer2( + testData, + testData, + (node1, node2) => UnionNode(Seq(node1, node2)), + testData.unionAll(testData).collect() + ) + } + + test("empty") { + checkAnswer2( + emptyTestData, + emptyTestData, + (node1, node2) => UnionNode(Seq(node1, node2)), + emptyTestData.unionAll(emptyTestData).collect() + ) + } + + test("complicated union") { + val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, + emptyTestData, emptyTestData, testData, emptyTestData) + doCheckAnswer( + dfs, + nodes => UnionNode(nodes), + dfs.reduce(_.unionAll(_)).collect() + ) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 1374a97476ca1..3fc02df954e23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -36,6 +36,13 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. + protected lazy val emptyTestData: DataFrame = { + val df = _sqlContext.sparkContext.parallelize( + Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() + df.registerTempTable("emptyTestData") + df + } + protected lazy val testData: DataFrame = { val df = _sqlContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() @@ -240,6 +247,7 @@ private[sql] trait SQLTestData { self => */ def loadTestData(): Unit = { assert(_sqlContext != null, "attempted to initialize test data before SQLContext.") + emptyTestData testData testData2 testData3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index cdd691e035897..dc08306ad9cb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,8 +27,9 @@ import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.util.Utils /** @@ -179,3 +180,46 @@ private[sql] trait SQLTestUtils DataFrame(_sqlContext, plan) } } + +private[sql] object SQLTestUtils { + + def compareAnswers( + sparkAnswer: Seq[Row], + expectedAnswer: Seq[Row], + sort: Boolean): Option[String] = { + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). 
+ // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + if (sort) { + converted.sortBy(_.toString()) + } else { + converted + } + } + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | == Results == + | ${sideBySide( + s"== Expected Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Actual Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + Some(errorMessage) + } else { + None + } + } +} From 905fbe498bdd29116468628e6a2a553c1fd57165 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 29 Aug 2015 23:26:23 -0700 Subject: [PATCH 144/802] [SPARK-10348] [MLLIB] updates ml-guide * replace `ML Dataset` by `DataFrame` to unify the abstraction * ML algorithms -> pipeline components to describe the main concept * remove Scala API doc links from the main guide * `Section Title` -> `Section tile` to be consistent with other section titles in MLlib guide * modified lines break at 100 chars or periods jkbradley feynmanliang Author: Xiangrui Meng Closes #8517 from mengxr/SPARK-10348. --- docs/ml-guide.md | 118 +++++++++++++++++++++++++++----------------- docs/mllib-guide.md | 12 ++--- 2 files changed, 78 insertions(+), 52 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index a92a285f3af85..4ba07542bfb40 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -24,61 +24,74 @@ title: Spark ML Programming Guide The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of [DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical machine learning pipelines. -See the [Algorithm Guides section](#algorithm-guides) below for guides on sub-packages of +See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of `spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. -**Table of Contents** +**Table of contents** * This will become a table of contents (this text will be scraped). {:toc} -# Main Concepts +# Main concepts -Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API. +Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types. -E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions. +* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. 
* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. -E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions. +E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. * **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. -E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. * **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Param`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. -## ML Dataset +## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept. +Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. `DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." -## ML Algorithms +## Pipeline components ### Transformers -A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns. +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. For example: -* A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset. -* A learning model might take a dataset, read the column containing feature vectors, predict the label for each feature vector, append the labels as a new column, and output the updated dataset. +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. 
+* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. ### Estimators -An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`. -For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`. +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. -### Properties of ML Algorithms +### Properties of pipeline components -`Transformer`s and `Estimator`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). @@ -91,15 +104,16 @@ E.g., a simple text document processing workflow might include several stages: * Convert each document's words into a numerical feature vector. * Learn a prediction model using the feature vectors and labels. -Spark ML represents such a workflow as a [`Pipeline`](api/scala/index.html#org.apache.spark.ml.Pipeline), -which consists of a sequence of [`PipelineStage`s](api/scala/index.html#org.apache.spark.ml.PipelineStage) (`Transformer`s and `Estimator`s) to be run in a specific order. We will use this simple workflow as a running example in this section. +Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. +We will use this simple workflow as a running example in this section. -### How It Works +### How it works A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. -These stages are run in order, and the input dataset is modified as it passes through each stage. -For `Transformer` stages, the `transform()` method is called on the dataset. -For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the dataset. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. 
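To make the workflow concrete, here is a hedged Scala sketch of a three-stage pipeline being fit and then applied, written in the style of a `spark-shell` session. It assumes a `sqlContext` is in scope; the toy data and column names are illustrative only.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Toy training data: id, text, label.
val training = sqlContext.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0)
)).toDF("id", "text", "label")

// Two Transformers followed by an Estimator, chained into a Pipeline (itself an Estimator).
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol(tokenizer.getOutputCol).setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// fit() runs the stages in order and returns a PipelineModel (a Transformer).
val model = pipeline.fit(training)

// At test time the fitted PipelineModel replays the identical feature processing.
val test = sqlContext.createDataFrame(Seq(
  (4L, "spark hadoop"),
  (5L, "a b c")
)).toDF("id", "text")
model.transform(test).select("id", "text", "prediction").show()
```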
@@ -115,14 +129,17 @@ We illustrate this for the simple text document workflow. The figure below is f Above, the top row represents a `Pipeline` with three stages. The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. -The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels. -The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset. -The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` method on the dataset before passing the dataset to the next stage. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. -Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel` which is a `Transformer`. This `PipelineModel` is used at *test time*; the figure below illustrates this usage. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage.

In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. -When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed through the `Pipeline` in order. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. Each stage's `transform()` method updates the dataset and passes it to the next stage. `Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. @@ -143,40 +161,48 @@ Each stage's `transform()` method updates the dataset and passes it to the next *DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. -*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`. +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. -A [`Param`](api/scala/index.html#org.apache.spark.ml.param.Param) is a named parameter with self-contained documentation. -A [`ParamMap`](api/scala/index.html#org.apache.spark.ml.param.ParamMap) is a set of (parameter, value) pairs. +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. There are two main ways to pass parameters to an algorithm: -1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. This API resembles the API used in MLlib. +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. 2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. 
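A minimal sketch of both ways of passing parameters, and of instance-scoped `Param`s, again assuming `sqlContext` is available:

{% highlight scala %}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.mllib.linalg.Vectors

val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0))
)).toDF("label", "features")

val lr1 = new LogisticRegression()
val lr2 = new LogisticRegression()
lr1.setMaxIter(5)  // 1. setter method on the instance

// 2. a ParamMap passed to fit(). Params are scoped to the instance that owns them,
// so lr1 resolves maxIter to 10 and lr2 to 20, and ParamMap values override setters.
val paramMap = ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)
val model1 = lr1.fit(training, paramMap)  // trained with maxIter = 10, not 5
val model2 = lr2.fit(training, paramMap)  // trained with maxIter = 20
{% endhighlight %}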
-# Algorithm Guides +# Algorithm guides There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. -**Pipelines API Algorithm Guides** - -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) -# Code Examples +# Code examples This section gives code examples illustrating the functionality discussed above. -There is not yet documentation for specific algorithms in Spark ML. For more info, please refer to the [API Documentation](api/scala/index.html#org.apache.spark.ml.package). Spark ML algorithms are currently wrappers for MLlib algorithms, and the [MLlib programming guide](mllib-guide.html) has details on specific algorithms. +For more info, please refer to the API documentation +([Scala](api/scala/index.html#org.apache.spark.ml.package), +[Java](api/java/org/apache/spark/ml/package-summary.html), +and [Python](api/python/pyspark.ml.html)). +Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the +[MLlib programming guide](mllib-guide.html) has details on specific algorithms. ## Example: Estimator, Transformer, and Param @@ -627,7 +653,7 @@ sc.stop() -## Example: Model Selection via Cross-Validation +## Example: model selection via cross-validation An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. @@ -873,11 +899,11 @@ jsc.stop(); -## Example: Model Selection via Train Validation Split +## Example: model selection via train validation split In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. `TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large.. + but will not produce as reliable results when the training dataset is not sufficiently large. `TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, and an `Evaluator`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 876dcfd40ed7b..257f7cc7603fa 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -14,9 +14,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: * [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API - built on top of RDDs. + built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). * [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API - built on top of DataFrames for constructing ML pipelines. 
+ built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. But we will keep supporting `spark.mllib` along with the development of `spark.ml`. @@ -57,19 +57,19 @@ We list major functionality from both below, with links to detailed guides. * [FP-growth](mllib-frequent-pattern-mining.html#fp-growth) * [association rules](mllib-frequent-pattern-mining.html#association-rules) * [PrefixSpan](mllib-frequent-pattern-mining.html#prefix-span) -* [Evaluation Metrics](mllib-evaluation-metrics.html) +* [Evaluation metrics](mllib-evaluation-metrics.html) +* [PMML model export](mllib-pmml-model-export.html) * [Optimization (developer)](mllib-optimization.html) * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) -* [PMML model export](mllib-pmml-model-export.html) # spark.ml: high-level APIs for ML pipelines **[spark.ml programming guide](ml-guide.html)** provides an overview of the Pipelines API and major concepts. It also contains sections on using algorithms within the Pipelines API, for example: -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Decision Trees for Classification and Regression](ml-decision-tree.html) +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision trees for classification and regression](ml-decision-tree.html) * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) From ca69fc8efda8a3e5442ffa16692a2b1eb86b7673 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 29 Aug 2015 23:57:09 -0700 Subject: [PATCH 145/802] [SPARK-10331] [MLLIB] Update example code in ml-guide * The example code was added in 1.2, before `createDataFrame`. This PR switches to `createDataFrame`. Java code still uses JavaBean. * assume `sqlContext` is available * fix some minor issues from previous code review jkbradley srowen feynmanliang Author: Xiangrui Meng Closes #8518 from mengxr/SPARK-10331. --- docs/ml-guide.md | 362 +++++++++++++++++++---------------------------- 1 file changed, 147 insertions(+), 215 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 4ba07542bfb40..78c93a95c7807 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -212,26 +212,18 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
{% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row -val conf = new SparkConf().setAppName("SimpleParamsExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training data. -// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes -// into DataFrames, where it uses the case class metadata to infer the schema. -val training = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) +// Prepare training data from a list of (label, features) tuples. +val training = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(0.0, 1.1, 0.1)), + (0.0, Vectors.dense(2.0, 1.0, -1.0)), + (0.0, Vectors.dense(2.0, 1.3, 1.0)), + (1.0, Vectors.dense(0.0, 1.2, -0.5)) +)).toDF("label", "features") // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() @@ -243,7 +235,7 @@ lr.setMaxIter(10) .setRegParam(0.01) // Learn a LogisticRegression model. This uses the parameters stored in lr. -val model1 = lr.fit(training.toDF) +val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -253,8 +245,8 @@ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) -paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. -paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -262,27 +254,27 @@ val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. -val model2 = lr.fit(training.toDF, paramMapCombined) +val model2 = lr.fit(training, paramMapCombined) println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. -val test = sc.parallelize(Seq( - LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) +val test = sqlContext.createDataFrame(Seq( + (1.0, Vectors.dense(-1.0, 1.5, 1.3)), + (0.0, Vectors.dense(3.0, 2.0, -0.1)), + (1.0, Vectors.dense(0.0, 2.2, -1.5)) +)).toDF("label", "features") // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
// Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. -model2.transform(test.toDF) +model2.transform(test) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => println(s"($features, $label) -> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
@@ -291,30 +283,23 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; -SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) +), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -334,14 +319,14 @@ LogisticRegressionModel model1 = lr.fit(training); System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap()); // We may alternatively specify parameters using a ParamMap. -ParamMap paramMap = new ParamMap(); -paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. -paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. -paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. +ParamMap paramMap = new ParamMap() + .put(lr.maxIter().w(20)) // Specify 1 Param. + .put(lr.maxIter(), 30) // This overwrites the original maxIter. + .put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params. // One can also combine ParamMaps. -ParamMap paramMap2 = new ParamMap(); -paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name +ParamMap paramMap2 = new ParamMap() + .put(lr.probabilityCol().w("myProbability")); // Change output column name ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -350,11 +335,11 @@ LogisticRegressionModel model2 = lr.fit(training, paramMapCombined); System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. -List localTest = Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); +DataFrame test = sqlContext.createDataFrame(Arrays.asList( + new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), + new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) +), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
@@ -366,28 +351,21 @@ for (Row r: results.select("features", "label", "myProbability", "prediction").c + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
{% highlight python %} -from pyspark import SparkContext -from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.linalg import Vectors from pyspark.ml.classification import LogisticRegression from pyspark.ml.param import Param, Params -from pyspark.sql import Row, SQLContext -sc = SparkContext(appName="SimpleParamsExample") -sqlContext = SQLContext(sc) - -# Prepare training data. -# We use LabeledPoint. -# Spark SQL can convert RDDs of LabeledPoints into DataFrames. -training = sc.parallelize([LabeledPoint(1.0, [0.0, 1.1, 0.1]), - LabeledPoint(0.0, [2.0, 1.0, -1.0]), - LabeledPoint(0.0, [2.0, 1.3, 1.0]), - LabeledPoint(1.0, [0.0, 1.2, -0.5])]) +# Prepare training data from a list of (label, features) tuples. +training = sqlContext.createDataFrame([ + (1.0, Vectors.dense([0.0, 1.1, 0.1])), + (0.0, Vectors.dense([2.0, 1.0, -1.0])), + (0.0, Vectors.dense([2.0, 1.3, 1.0])), + (1.0, Vectors.dense([0.0, 1.2, -0.5]))], ["label", "features"]) # Create a LogisticRegression instance. This instance is an Estimator. lr = LogisticRegression(maxIter=10, regParam=0.01) @@ -395,7 +373,7 @@ lr = LogisticRegression(maxIter=10, regParam=0.01) print "LogisticRegression parameters:\n" + lr.explainParams() + "\n" # Learn a LogisticRegression model. This uses the parameters stored in lr. -model1 = lr.fit(training.toDF()) +model1 = lr.fit(training) # Since model1 is a Model (i.e., a transformer produced by an Estimator), # we can view the parameters it used during fit(). @@ -416,25 +394,25 @@ paramMapCombined.update(paramMap2) # Now learn a new model using the paramMapCombined parameters. # paramMapCombined overrides all parameters set earlier via lr.set* methods. -model2 = lr.fit(training.toDF(), paramMapCombined) +model2 = lr.fit(training, paramMapCombined) print "Model 2 was fit using parameters: " print model2.extractParamMap() # Prepare test data -test = sc.parallelize([LabeledPoint(1.0, [-1.0, 1.5, 1.3]), - LabeledPoint(0.0, [ 3.0, 2.0, -0.1]), - LabeledPoint(1.0, [ 0.0, 2.2, -1.5])]) +test = sqlContext.createDataFrame([ + (1.0, Vectors.dense([-1.0, 1.5, 1.3])), + (0.0, Vectors.dense([3.0, 2.0, -0.1])), + (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) # Make predictions on test data using the Transformer.transform() method. # LogisticRegression.transform will only use the 'features' column. # Note that model2.transform() outputs a "myProbability" column instead of the usual # 'probability' column since we renamed the lr.probabilityCol parameter previously. -prediction = model2.transform(test.toDF()) +prediction = model2.transform(test) selected = prediction.select("features", "label", "myProbability", "prediction") for row in selected.collect(): print row -sc.stop() {% endhighlight %}
@@ -448,30 +426,19 @@ This example follows the simple text document `Pipeline` illustrated in the figu
{% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) +import org.apache.spark.sql.Row -// Set up contexts. Import implicit conversions to DataFrame from sqlContext. -val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0))) +// Prepare training documents from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -488,14 +455,15 @@ val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) // Fit the pipeline to training documents. -val model = pipeline.fit(training.toDF) +val model = pipeline.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. model.transform(test.toDF) @@ -505,7 +473,6 @@ model.transform(test.toDF) println(s"($id, $text) --> prob=$prob, prediction=$prediction") } -sc.stop() {% endhighlight %}
@@ -514,8 +481,6 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -524,7 +489,6 @@ import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -556,18 +520,13 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -// Set up contexts. -SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); - // Prepare training documents, which are labeled. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(3L, "hadoop mapreduce", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -587,12 +546,12 @@ Pipeline pipeline = new Pipeline() PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Arrays.asList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. DataFrame predictions = model.transform(test); @@ -601,28 +560,23 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %}
{% highlight python %} -from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row, SQLContext - -sc = SparkContext(appName="SimpleTextClassificationPipeline") -sqlContext = SQLContext(sc) +from pyspark.sql import Row -# Prepare training documents, which are labeled. +# Prepare training documents from a list of (id, text, label) tuples. LabeledDocument = Row("id", "text", "label") -training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ - .map(lambda x: LabeledDocument(*x)).toDF() +training = sqlContext.createDataFrame([ + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") @@ -633,13 +587,12 @@ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. model = pipeline.fit(training) -# Prepare test documents, which are unlabeled. -Document = Row("id", "text") -test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() +# Prepare test documents, which are unlabeled (id, text) tuples. +test = sqlContext.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")], ["id", "text"]) # Make predictions on test documents and print columns of interest. prediction = model.transform(test) @@ -647,7 +600,6 @@ selected = prediction.select("id", "text", "prediction") for row in selected.collect(): print(row) -sc.stop() {% endhighlight %}
@@ -664,8 +616,8 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the setMetric +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. @@ -684,39 +636,29 @@ However, it is also a well-established method for choosing parameters which is m
{% highlight scala %} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -// Labeled and unlabeled instance types. -// Spark SQL can infer schema from case classes. -case class LabeledDocument(id: Long, text: String, label: Double) -case class Document(id: Long, text: String) - -val conf = new SparkConf().setAppName("CrossValidatorExample") -val sc = new SparkContext(conf) -val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ - -// Prepare training documents, which are labeled. -val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0), - LabeledDocument(4L, "b spark who", 1.0), - LabeledDocument(5L, "g d a y", 0.0), - LabeledDocument(6L, "spark fly", 1.0), - LabeledDocument(7L, "was mapreduce", 0.0), - LabeledDocument(8L, "e spark program", 1.0), - LabeledDocument(9L, "a e c l", 0.0), - LabeledDocument(10L, "spark compile", 1.0), - LabeledDocument(11L, "hadoop software", 0.0))) +import org.apache.spark.sql.Row + +// Prepare training data from a list of (id, text, label) tuples. +val training = sqlContext.createDataFrame(Seq( + (0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0), + (4L, "b spark who", 1.0), + (5L, "g d a y", 0.0), + (6L, "spark fly", 1.0), + (7L, "was mapreduce", 0.0), + (8L, "e spark program", 1.0), + (9L, "a e c l", 0.0), + (10L, "spark compile", 1.0), + (11L, "hadoop software", 0.0) +)).toDF("id", "text", "label") // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. val tokenizer = new Tokenizer() @@ -730,15 +672,6 @@ val lr = new LogisticRegression() val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric -// used is areaUnderROC. -val crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -746,28 +679,37 @@ val paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) .addGrid(lr.regParam, Array(0.1, 0.01)) .build() -crossval.setEstimatorParamMaps(paramGrid) -crossval.setNumFolds(2) // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. 
+// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. +val cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -val cvModel = crossval.fit(training.toDF) +val cvModel = cv.fit(training) -// Prepare test documents, which are unlabeled. -val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) +// Prepare test documents, which are unlabeled (id, text) tuples. +val test = sqlContext.createDataFrame(Seq( + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") +)).toDF("id", "text") // Make predictions on test documents. cvModel uses the best model found (lrModel). -cvModel.transform(test.toDF) +cvModel.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") -} + println(s"($id, $text) --> prob=$prob, prediction=$prediction") + } -sc.stop() {% endhighlight %}
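A minimal follow-up sketch, assuming the `cvModel` fit above: the fitted `CrossValidatorModel` exposes the winning fitted pipeline through `bestModel`, so the tuned stages can be inspected after `fit()`:

{% highlight scala %}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegressionModel

// bestModel is the PipelineModel trained with the best-performing ParamMap from paramGrid.
val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
val bestLr = bestPipeline.stages.last.asInstanceOf[LogisticRegressionModel]
println("Best LogisticRegression parameters:\n" + bestLr.extractParamMap())
{% endhighlight %}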
@@ -776,8 +718,6 @@ sc.stop() import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; @@ -790,7 +730,6 @@ import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; // Labeled and unlabeled instance types. // Spark SQL can infer schema from Java Beans. @@ -822,12 +761,9 @@ public class LabeledDocument extends Document implements Serializable { public void setLabel(double label) { this.label = label; } } -SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); -JavaSparkContext jsc = new JavaSparkContext(conf); -SQLContext jsql = new SQLContext(jsc); // Prepare training documents, which are labeled. -List localTraining = Arrays.asList( +DataFrame training = sqlContext.createDataFrame(Arrays.asList( new LabeledDocument(0L, "a b c d e spark", 1.0), new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), @@ -839,8 +775,8 @@ List localTraining = Arrays.asList( new LabeledDocument(8L, "e spark program", 1.0), new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); -DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + new LabeledDocument(11L, "hadoop software", 0.0) +), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -856,15 +792,6 @@ LogisticRegression lr = new LogisticRegression() Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); -// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. -// This will allow us to jointly choose parameters for all Pipeline stages. -// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. -// Note that the evaluator here is a BinaryClassificationEvaluator and the default metric -// used is areaUnderROC. -CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); - // We use a ParamGridBuilder to construct a grid of parameters to search over. // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. @@ -872,19 +799,28 @@ ParamMap[] paramGrid = new ParamGridBuilder() .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) .addGrid(lr.regParam(), new double[]{0.1, 0.01}) .build(); -crossval.setEstimatorParamMaps(paramGrid); -crossval.setNumFolds(2); // Use 3+ in practice + +// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. +// This will allow us to jointly choose parameters for all Pipeline stages. +// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. +// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric +// is areaUnderROC. 
+CrossValidator cv = new CrossValidator() + .setEstimator(pipeline) + .setEvaluator(new BinaryClassificationEvaluator()) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2); // Use 3+ in practice // Run cross-validation, and choose the best set of parameters. -CrossValidatorModel cvModel = crossval.fit(training); +CrossValidatorModel cvModel = cv.fit(training); // Prepare test documents, which are unlabeled. -List localTest = Arrays.asList( +DataFrame test = sqlContext.createDataFrame(Arrays.asList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); -DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + new Document(7L, "apache hadoop") +), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). DataFrame predictions = cvModel.transform(test); @@ -893,7 +829,6 @@ for (Row r: predictions.select("id", "text", "probability", "prediction").collec + ", prediction=" + r.get(3)); } -jsc.stop(); {% endhighlight %} @@ -935,7 +870,7 @@ val lr = new LinearRegression() // the evaluator. val paramGrid = new ParamGridBuilder() .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept, Array(true, false)) + .addGrid(lr.fitIntercept) .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) .build() @@ -945,9 +880,8 @@ val trainValidationSplit = new TrainValidationSplit() .setEstimator(lr) .setEvaluator(new RegressionEvaluator) .setEstimatorParamMaps(paramGrid) - -// 80% of the data will be used for training and the remaining 20% for validation. -trainValidationSplit.setTrainRatio(0.8) + // 80% of the data will be used for training and the remaining 20% for validation. + .setTrainRatio(0.8) // Run train validation split, and choose the best set of parameters. val model = trainValidationSplit.fit(training) @@ -972,12 +906,12 @@ import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; -DataFrame data = jsql.createDataFrame( +DataFrame data = sqlContext.createDataFrame( MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"), LabeledPoint.class); // Prepare training and test data. -DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); +DataFrame[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345); DataFrame training = splits[0]; DataFrame test = splits[1]; @@ -997,10 +931,8 @@ ParamMap[] paramGrid = new ParamGridBuilder() TrainValidationSplit trainValidationSplit = new TrainValidationSplit() .setEstimator(lr) .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid); - -// 80% of the data will be used for training and the remaining 20% for validation. -trainValidationSplit.setTrainRatio(0.8); + .setEstimatorParamMaps(paramGrid) + .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation // Run train validation split, and choose the best set of parameters. TrainValidationSplitModel model = trainValidationSplit.fit(training); From 1bfd9347822df65e76201c4c471a26488d722319 Mon Sep 17 00:00:00 2001 From: ihainan Date: Sun, 30 Aug 2015 08:26:14 +0100 Subject: [PATCH 146/802] [SPARK-10184] [CORE] Optimization for bounds determination in RangePartitioner JIRA Issue: https://issues.apache.org/jira/browse/SPARK-10184 Change `cumWeight > target` to `cumWeight >= target` in `RangePartitioner.determineBounds` method to make the output partitions more balanced. 
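A simplified sketch of the bound-selection loop (not the actual private `determineBounds`, which also skips duplicate keys) shows why the inclusive comparison yields more balanced partitions:

{% highlight scala %}
// Four equally weighted candidate keys, 2 partitions => target = totalWeight / 2 = 2.0.
val ordered = Seq((10, 1.0), (20, 1.0), (30, 1.0), (40, 1.0))
val target = ordered.map(_._2).sum / 2

var cumWeight = 0.0
val bounds = scala.collection.mutable.ArrayBuffer[Int]()
for ((key, weight) <- ordered if bounds.isEmpty) {  // only 1 bound is needed for 2 partitions
  cumWeight += weight
  // With ">",  the bound is not taken until key 30, giving a lopsided {10, 20, 30} / {40} split.
  // With ">=", the bound is taken at key 20, giving a balanced {10, 20} / {30, 40} split.
  if (cumWeight >= target) {
    bounds += key
  }
}
{% endhighlight %}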
Author: ihainan Closes #8397 from ihainan/opt_for_rangepartitioner. --- core/src/main/scala/org/apache/spark/Partitioner.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 4b9d59975bdc2..29e581bb57cbc 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -291,7 +291,7 @@ private[spark] object RangePartitioner { while ((i < numCandidates) && (j < partitions - 1)) { val (key, weight) = ordered(i) cumWeight += weight - if (cumWeight > target) { + if (cumWeight >= target) { // Skip duplicate values. if (previousBound.isEmpty || ordering.gt(key, previousBound.get)) { bounds += key From 8d2ab75d3b71b632f2394f2453af32f417cb45e5 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sun, 30 Aug 2015 12:21:15 -0700 Subject: [PATCH 147/802] [SPARK-10353] [MLLIB] BLAS gemm not scaling when beta = 0.0 for some subset of matrix multiplications mengxr jkbradley rxin It would be great if this fix made it into RC3! Author: Burak Yavuz Closes #8525 from brkyvz/blas-scaling. --- .../org/apache/spark/mllib/linalg/BLAS.scala | 26 +++++++------------ .../apache/spark/mllib/linalg/BLASSuite.scala | 5 ++++ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index bbbcc8436b7c2..ab475af264dd3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -305,6 +305,8 @@ private[spark] object BLAS extends Serializable with Logging { "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0 && beta == 1.0) { logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.") + } else if (alpha == 0.0) { + f2jBLAS.dscal(C.values.length, beta, C.values, 1) } else { A match { case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) @@ -408,8 +410,8 @@ private[spark] object BLAS extends Serializable with Logging { } } } else { - // Scale matrix first if `beta` is not equal to 0.0 - if (beta != 0.0) { + // Scale matrix first if `beta` is not equal to 1.0 + if (beta != 1.0) { f2jBLAS.dscal(C.values.length, beta, C.values, 1) } // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of @@ -470,8 +472,10 @@ private[spark] object BLAS extends Serializable with Logging { s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") require(A.numRows == y.size, s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}") - if (alpha == 0.0) { - logDebug("gemv: alpha is equal to 0. Returning y.") + if (alpha == 0.0 && beta == 1.0) { + logDebug("gemv: alpha is equal to 0 and beta is equal to 1. 
Returning y.") + } else if (alpha == 0.0) { + scal(beta, y) } else { (A, x) match { case (smA: SparseMatrix, dvx: DenseVector) => @@ -526,11 +530,6 @@ private[spark] object BLAS extends Serializable with Logging { val xValues = x.values val yValues = y.values - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounterForA = 0 while (rowCounterForA < mA) { @@ -581,11 +580,6 @@ private[spark] object BLAS extends Serializable with Logging { val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices - if (alpha == 0.0) { - scal(beta, y) - return - } - if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { @@ -604,7 +598,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) var colCounterForA = 0 var k = 0 @@ -659,7 +653,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - scal(beta, y) + if (beta != 1.0) scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index d119e0b50a393..8db5c8424abe9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -204,6 +204,7 @@ class BLASSuite extends SparkFunSuite { val C14 = C1.copy val C15 = C1.copy val C16 = C1.copy + val C17 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0)) @@ -217,6 +218,10 @@ class BLASSuite extends SparkFunSuite { assert(C2 ~== expected2 absTol 1e-15) assert(C3 ~== expected3 absTol 1e-15) assert(C4 ~== expected3 absTol 1e-15) + gemm(1.0, dA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) + gemm(1.0, sA, B, 0.0, C17) + assert(C17 ~== expected absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { From 35e896a79bb5e72d63b82b047f46f4f6fa2e1970 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sun, 30 Aug 2015 21:39:16 -0700 Subject: [PATCH 148/802] SPARK-9545, SPARK-9547: Use Maven in PRB if title contains "[test-maven]" This is just some small glue code to actually make use of the AMPLAB_JENKINS_BUILD_TOOL switch. As far as I can tell, we actually don't currently use the Maven support in the tool even though it exists. This patch switches to Maven when the PR title contains "test-maven". There are a few small other pieces of cleanup in the patch as well. Author: Patrick Wendell Closes #7878 from pwendell/maven-tests. 
--- dev/run-tests-jenkins | 18 ++++++++++++++++-- dev/run-tests.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 39cf54f78104c..3be78575e70f1 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -164,8 +164,9 @@ pr_message="" current_pr_head="`git rev-parse HEAD`" echo "HEAD: `git rev-parse HEAD`" -echo "GHPRB: $ghprbActualCommit" -echo "SHA1: $sha1" +echo "\$ghprbActualCommit: $ghprbActualCommit" +echo "\$sha1: $sha1" +echo "\$ghprbPullTitle: $ghprbPullTitle" # Run pull request tests for t in "${PR_TESTS[@]}"; do @@ -189,6 +190,19 @@ done { # Marks this build is a pull request build. export AMP_JENKINS_PRB=true + if [[ $ghprbPullTitle == *"test-maven"* ]]; then + export AMPLAB_JENKINS_BUILD_TOOL="maven" + fi + if [[ $ghprbPullTitle == *"test-hadoop1.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop1.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.0"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.0" + elif [[ $ghprbPullTitle == *"test-hadoop2.2"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.2" + elif [[ $ghprbPullTitle == *"test-hadoop2.3"* ]]; then + export AMPLAB_JENKINS_BUILD_PROFILE="hadoop2.3" + fi + timeout "${TESTS_TIMEOUT}" ./dev/run-tests test_result="$?" diff --git a/dev/run-tests.py b/dev/run-tests.py index 4fd703a7c219f..d8b22e1665e7b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -21,6 +21,7 @@ import itertools from optparse import OptionParser import os +import random import re import sys import subprocess @@ -239,11 +240,32 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) +def get_zinc_port(): + """ + Get a randomized port on which to start Zinc + """ + return random.randrange(3030, 4030) + + +def kill_zinc_on_port(zinc_port): + """ + Kill the Zinc process running on the given port, if one exists. + """ + cmd = ("/usr/sbin/lsof -P |grep %s | grep LISTEN " + "| awk '{ print $2; }' | xargs kill") % zinc_port + subprocess.check_call(cmd, shell=True) + + def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" - run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) + zinc_port = get_zinc_port() + os.environ["ZINC_PORT"] = "%s" % zinc_port + zinc_flag = "-DzincPort=%s" % zinc_port + flags = [os.path.join(SPARK_HOME, "build", "mvn"), "--force", zinc_flag] + run_cmd(flags + mvn_args) + kill_zinc_on_port(zinc_port) def exec_sbt(sbt_args=()): @@ -514,7 +536,9 @@ def main(): build_apache_spark(build_tool, hadoop_version) # backwards compatibility checks - detect_binary_inop_with_mima() + if build_tool == "sbt": + # Note: compatiblity tests only supported in sbt for now + detect_binary_inop_with_mima() # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules) From 8694c3ad7dcafca9563649e93b7a08076748d6f2 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sun, 30 Aug 2015 23:12:56 -0700 Subject: [PATCH 149/802] [SPARK-10351] [SQL] Fixes UTF8String.fromAddress to handle off-heap memory CC rxin marmbrus Author: Feynman Liang Closes #8523 from feynmanliang/SPARK-10351. 
--- .../test/scala/org/apache/spark/sql/UnsafeRowSuite.scala | 9 +++++---- .../java/org/apache/spark/unsafe/types/UTF8String.java | 6 +----- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 219435dff5bc8..2476b10e3cf9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -43,12 +43,12 @@ class UnsafeRowSuite extends SparkFunSuite { val arrayBackedUnsafeRow: UnsafeRow = UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) - val bytesFromArrayBackedRow: Array[Byte] = { + val (bytesFromArrayBackedRow, field0StringFromArrayBackedRow): (Array[Byte], String) = { val baos = new ByteArrayOutputStream() arrayBackedUnsafeRow.writeToStream(baos, null) - baos.toByteArray + (baos.toByteArray, arrayBackedUnsafeRow.getString(0)) } - val bytesFromOffheapRow: Array[Byte] = { + val (bytesFromOffheapRow, field0StringFromOffheapRow): (Array[Byte], String) = { val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) try { Platform.copyMemory( @@ -69,13 +69,14 @@ class UnsafeRowSuite extends SparkFunSuite { val baos = new ByteArrayOutputStream() val writeBuffer = new Array[Byte](1024) offheapUnsafeRow.writeToStream(baos, writeBuffer) - baos.toByteArray + (baos.toByteArray, offheapUnsafeRow.getString(0)) } finally { MemoryAllocator.UNSAFE.free(offheapRowPage) } } assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + assert(field0StringFromArrayBackedRow === field0StringFromOffheapRow) } test("calling getDouble() and getFloat() on null columns") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index cbcab958c05a9..216aeea60d1c8 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -90,11 +90,7 @@ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { * Creates an UTF8String from given address (base and offset) and length. */ public static UTF8String fromAddress(Object base, long offset, int numBytes) { - if (base != null) { - return new UTF8String(base, offset, numBytes); - } else { - return null; - } + return new UTF8String(base, offset, numBytes); } /** From f0f563a3c43fc9683e6920890cce44611c0c5f4b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 30 Aug 2015 23:20:03 -0700 Subject: [PATCH 150/802] [SPARK-100354] [MLLIB] fix some apparent memory issues in k-means|| initializaiton * do not cache first cost RDD * change following cost RDD cache level to MEMORY_AND_DISK * remove Vector wrapper to save a object per instance Further improvements will be addressed in SPARK-10329 cc: yu-iskw HuJiayin Author: Xiangrui Meng Closes #8526 from mengxr/SPARK-10354. 
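A simplified sketch of the caching discipline behind these changes (`updateCosts` and `pointCost` are hypothetical names; `pointCost` stands in for the distance computation against the broadcast centers shown in the diff):

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

def updateCosts(
    data: RDD[Vector],
    runs: Int,
    steps: Int)(pointCost: (Int, Vector) => Double): RDD[Array[Double]] = {
  // First cost RDD: cheap to recompute, so it is no longer cached.
  var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity))
  var step = 0
  while (step < steps) {
    val preCosts = costs
    costs = data.zip(preCosts).map { case (point, cost) =>
      // Plain Array[Double] instead of a Vector wrapper: one fewer object per point.
      Array.tabulate(runs)(r => math.min(pointCost(r, point), cost(r)))
    }.persist(StorageLevel.MEMORY_AND_DISK)  // spill to disk rather than drop and recompute
    costs.count()  // materialize (the real code does this via an aggregate over all runs)
    preCosts.unpersist(blocking = false)  // release the previous step's cached costs
    step += 1
  }
  costs
}
{% endhighlight %}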
--- .../spark/mllib/clustering/KMeans.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 46920fffe6e1a..7168aac32c997 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -369,7 +369,7 @@ class KMeans private ( : Array[Array[VectorWithNorm]] = { // Initialize empty centers and point costs. val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm]) - var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache() + var costs = data.map(_ => Array.fill(runs)(Double.PositiveInfinity)) // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() @@ -394,21 +394,28 @@ class KMeans private ( val bcNewCenters = data.context.broadcast(newCenters) val preCosts = costs costs = data.zip(preCosts).map { case (point, cost) => - Vectors.dense( Array.tabulate(runs) { r => math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r)) - }) - }.cache() + } + }.persist(StorageLevel.MEMORY_AND_DISK) val sumCosts = costs - .aggregate(Vectors.zeros(runs))( + .aggregate(new Array[Double](runs))( seqOp = (s, v) => { // s += v - axpy(1.0, v, s) + var r = 0 + while (r < runs) { + s(r) += v(r) + r += 1 + } s }, combOp = (s0, s1) => { // s0 += s1 - axpy(1.0, s1, s0) + var r = 0 + while (r < runs) { + s0(r) += s1(r) + r += 1 + } s0 } ) From 72f6dbf7b0c8b271f5f9c762374422c69c8ab43d Mon Sep 17 00:00:00 2001 From: EugenCepoi Date: Mon, 31 Aug 2015 13:24:35 -0500 Subject: [PATCH 151/802] [SPARK-8730] Fixes - Deser objects containing a primitive class attribute Author: EugenCepoi Closes #7122 from EugenCepoi/master. 
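A minimal sketch of the failure mode being fixed (`HoldsPrimitiveClass` is a hypothetical class mirroring the test added in this patch): a serialized field holding `classOf[Int]` is written out under the class name `"int"`, which `Class.forName` cannot resolve, so deserialization used to fail with `ClassNotFoundException`:

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Any serializable object carrying a primitive class token triggers the problem.
class HoldsPrimitiveClass extends Serializable {
  val intClass: Class[_] = classOf[Int]  // recorded under the name "int"
}

val instance = new JavaSerializer(new SparkConf()).newInstance()
val bytes = instance.serialize(new HoldsPrimitiveClass())
// Before this patch: ClassNotFoundException("int"); after: the primitive name is mapped back.
val restored = instance.deserialize[HoldsPrimitiveClass](bytes)
assert(restored.intClass == classOf[Int])
{% endhighlight %}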
--- .../spark/serializer/JavaSerializer.scala | 27 +++++++++++++++---- .../serializer/JavaSerializerSuite.scala | 18 +++++++++++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 4a5274b46b7a0..b463a71d5bd7d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -62,17 +62,34 @@ private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoa extends DeserializationStream { private val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass): Class[_] = { - // scalastyle:off classforname - Class.forName(desc.getName, false, loader) - // scalastyle:on classforname - } + override def resolveClass(desc: ObjectStreamClass): Class[_] = + try { + // scalastyle:off classforname + Class.forName(desc.getName, false, loader) + // scalastyle:on classforname + } catch { + case e: ClassNotFoundException => + JavaDeserializationStream.primitiveMappings.get(desc.getName).getOrElse(throw e) + } } def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T] def close() { objIn.close() } } +private object JavaDeserializationStream { + val primitiveMappings = Map[String, Class[_]]( + "boolean" -> classOf[Boolean], + "byte" -> classOf[Byte], + "char" -> classOf[Char], + "short" -> classOf[Short], + "int" -> classOf[Int], + "long" -> classOf[Long], + "float" -> classOf[Float], + "double" -> classOf[Double], + "void" -> classOf[Void] + ) +} private[spark] class JavaSerializerInstance( counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 329a2b6dad831..20f45670bc2ba 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -25,4 +25,22 @@ class JavaSerializerSuite extends SparkFunSuite { val instance = serializer.newInstance() instance.deserialize[JavaSerializer](instance.serialize(serializer)) } + + test("Deserialize object containing a primitive Class as attribute") { + val serializer = new JavaSerializer(new SparkConf()) + val instance = serializer.newInstance() + instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + } +} + +private class ContainsPrimitiveClass extends Serializable { + val intClass = classOf[Int] + val longClass = classOf[Long] + val shortClass = classOf[Short] + val charClass = classOf[Char] + val doubleClass = classOf[Double] + val floatClass = classOf[Float] + val booleanClass = classOf[Boolean] + val byteClass = classOf[Byte] + val voidClass = classOf[Void] } From 4a5fe091658b1d06f427e404a11a84fc84f953c5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 31 Aug 2015 12:19:11 -0700 Subject: [PATCH 152/802] [SPARK-10369] [STREAMING] Don't remove ReceiverTrackingInfo when deregisterReceivering since we may reuse it later `deregisterReceiver` should not remove `ReceiverTrackingInfo`. Otherwise, it will throw `java.util.NoSuchElementException: key not found` when restarting it. Author: zsxwing Closes #8538 from zsxwing/SPARK-10369. 
--- .../streaming/scheduler/ReceiverTracker.scala | 4 +- .../scheduler/ReceiverTrackerSuite.scala | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 3d532a675db02..f86fd44b48719 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -291,7 +291,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ReceiverTrackingInfo( streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverTrackingInfos -= streamId + receiverTrackingInfos(streamId) = newReceiverTrackingInfo listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" @@ -483,7 +483,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false context.reply(true) // Local messages case AllReceiverIds => - context.reply(receiverTrackingInfos.keys.toSeq) + context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index dd292ba4dd949..45138b748ecab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -60,6 +60,26 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("should restart receiver after stopping it") { + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + @volatile var startTimes = 0 + ssc.addStreamingListener(new StreamingListener { + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + startTimes += 1 + } + }) + val input = ssc.receiverStream(new StoppableReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + StoppableReceiver.shouldStop = true + eventually(timeout(10 seconds), interval(10 millis)) { + // The receiver is stopped once, so if it's restarted, it should be started twice. 
+ assert(startTimes === 2) + } + } + } } /** An input DStream with for testing rate controlling */ @@ -132,3 +152,34 @@ private[streaming] object RateTestReceiver { def getActive(): Option[RateTestReceiver] = Option(activeReceiver) } + +/** + * A custom receiver that could be stopped via StoppableReceiver.shouldStop + */ +class StoppableReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + var receivingThreadOption: Option[Thread] = None + + def onStart() { + val thread = new Thread() { + override def run() { + while (!StoppableReceiver.shouldStop) { + Thread.sleep(10) + } + StoppableReceiver.this.stop("stop") + } + } + thread.start() + } + + def onStop() { + StoppableReceiver.shouldStop = true + receivingThreadOption.foreach(_.join()) + // Reset it so as to restart it + StoppableReceiver.shouldStop = false + } +} + +object StoppableReceiver { + @volatile var shouldStop = false +} From a2d5c72091b1c602694dbca823a7b26f86b02864 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Mon, 31 Aug 2015 12:39:58 -0700 Subject: [PATCH 153/802] [SPARK-10170] [SQL] Add DB2 JDBC dialect support. Data frame write to DB2 database is failing because by default JDBC data source implementation is generating a table schema with DB2 unsupported data types TEXT for String, and BIT1(1) for Boolean. This patch registers DB2 JDBC Dialect that maps String, Boolean to valid DB2 data types. Author: sureshthalamati Closes #8393 from sureshthalamati/db2_dialect_spark-10170. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 7 +++++++ 2 files changed, 25 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8849fc2f1f0ef..c6d05c9b83b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -125,6 +125,7 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) + registerDialect(DB2Dialect) /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -222,3 +223,20 @@ case object MySQLDialect extends JdbcDialect { s"`$colName`" } } + +/** + * :: DeveloperApi :: + * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. + * By default string, and boolean gets mapped to db2 invalid types TEXT, and BIT(1). 
+ */ +@DeveloperApi +case object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0edac0848c3bb..d8c9a08d84c61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -407,6 +407,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext test("Default jdbc dialect registration") { assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } @@ -443,4 +444,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } + + test("DB2Dialect type mapping") { + val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") + assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + } } From 23e39cc7b1bb7f1087c4706234c9b5165a571357 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 31 Aug 2015 15:49:25 -0700 Subject: [PATCH 154/802] [SPARK-9954] [MLLIB] use first 128 nonzeros to compute Vector.hashCode This could help reduce hash collisions, e.g., in `RDD[Vector].repartition`. jkbradley Author: Xiangrui Meng Closes #8182 from mengxr/SPARK-9954. --- .../apache/spark/mllib/linalg/Vectors.scala | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 06ebb15869909..3642e9286504f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -71,20 +71,22 @@ sealed trait Vector extends Serializable { } /** - * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros - * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + * Returns a hash code value for the vector. The hash code is based on its size and its first 128 + * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. */ override def hashCode(): Int = { // This is a reference implementation. It calls return in foreachActive, which is slow. // Subclasses should override it with optimized implementation. 
var result: Int = 31 + size + var nnz = 0 this.foreachActive { (index, value) => - if (index < 16) { + if (nnz < Vectors.MAX_HASH_NNZ) { // ignore explicit 0 for comparison between sparse and dense if (value != 0) { result = 31 * result + index val bits = java.lang.Double.doubleToLongBits(value) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } } else { return result @@ -536,6 +538,9 @@ object Vectors { } allEqual } + + /** Max number of nonzero entries used in computing hash code. */ + private[linalg] val MAX_HASH_NNZ = 128 } /** @@ -578,13 +583,15 @@ class DenseVector @Since("1.0.0") ( override def hashCode(): Int = { var result: Int = 31 + size var i = 0 - val end = math.min(values.length, 16) - while (i < end) { + val end = values.length + var nnz = 0 + while (i < end && nnz < Vectors.MAX_HASH_NNZ) { val v = values(i) if (v != 0.0) { result = 31 * result + i val bits = java.lang.Double.doubleToLongBits(values(i)) result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } i += 1 } @@ -707,19 +714,16 @@ class SparseVector @Since("1.0.0") ( override def hashCode(): Int = { var result: Int = 31 + size val end = values.length - var continue = true var k = 0 - while ((k < end) & continue) { - val i = indices(k) - if (i < 16) { - val v = values(k) - if (v != 0.0) { - result = 31 * result + i - val bits = java.lang.Double.doubleToLongBits(v) - result = 31 * result + (bits ^ (bits >>> 32)).toInt - } - } else { - continue = false + var nnz = 0 + while (k < end && nnz < Vectors.MAX_HASH_NNZ) { + val v = values(k) + if (v != 0.0) { + val i = indices(k) + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + nnz += 1 } k += 1 } From 5b3245d6dff65972fc39c73f90d5cbdf84d19129 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 15:50:41 -0700 Subject: [PATCH 155/802] [SPARK-8472] [ML] [PySpark] Python API for DCT Add Python API for ml.feature.DCT. Author: Yanbo Liang Closes #8485 from yanboliang/spark-8472. --- python/pyspark/ml/feature.py | 65 +++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b2b2ccc9e55..59300a607815b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,7 +26,7 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', +__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', @@ -166,6 +166,69 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class DCT(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that takes the 1D discrete cosine transform + of a real vector. No zero padding is performed on the input vector. + It returns a real vector of the same length representing the DCT. + The return vector is scaled such that the transform matrix is + unitary (aka scaled DCT-II). + + More information on + `https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia`. 
+ + >>> from pyspark.mllib.linalg import Vectors + >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) + >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec") + >>> df2 = dct.transform(df1) + >>> df2.head().resultVec + DenseVector([10.969..., -0.707..., -2.041...]) + >>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2) + >>> df3.head().origVec + DenseVector([5.0, 8.0, 6.0]) + """ + + # a placeholder to make it appear in the generated doc + inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + + @keyword_only + def __init__(self, inverse=False, inputCol=None, outputCol=None): + """ + __init__(self, inverse=False, inputCol=None, outputCol=None) + """ + super(DCT, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) + self.inverse = Param(self, "inverse", "Set transformer to perform inverse DCT, " + + "default False.") + self._setDefault(inverse=False) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inverse=False, inputCol=None, outputCol=None): + """ + setParams(self, inverse=False, inputCol=None, outputCol=None) + Sets params for this DCT. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setInverse(self, value): + """ + Sets the value of :py:attr:`inverse`. + """ + self._paramMap[self.inverse] = value + return self + + def getInverse(self): + """ + Gets the value of inverse or its default value. + """ + return self.getOrDefault(self.inverse) + + @inherit_doc class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol): """ From 540bdee93103a73736d282b95db6a8cda8f6a2b1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 31 Aug 2015 15:55:22 -0700 Subject: [PATCH 156/802] [SPARK-10341] [SQL] fix memory starving in unsafe SMJ In SMJ, the first ExternalSorter could consume all the memory before spilling, then the second can not even acquire the first page. Before we have a better memory allocator, SMJ should call prepare() before call any compute() of it's children. cc rxin JoshRosen Author: Davies Liu Closes #8511 from davies/smj_memory. --- .../rdd/MapPartitionsWithPreparationRDD.scala | 21 +++++++++++++++++-- .../spark/rdd/ZippedPartitionsRDD.scala | 13 ++++++++++++ ...MapPartitionsWithPreparationRDDSuite.scala | 14 +++++++++---- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala index b475bd8d79f85..1f2213d0c4346 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.{Partition, Partitioner, TaskContext} @@ -38,12 +39,28 @@ private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M override def getPartitions: Array[Partition] = firstParent[T].partitions + // In certain join operations, prepare can be called on the same partition multiple times. + // In this case, we need to ensure that each call to compute gets a separate prepare argument. 
+ private[this] var preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M] + + /** + * Prepare a partition for a single call to compute. + */ + def prepare(): Unit = { + preparedArguments += preparePartition() + } + /** * Prepare a partition before computing it from its parent. */ override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val preparedArgument = preparePartition() + val prepared = + if (preparedArguments.isEmpty) { + preparePartition() + } else { + preparedArguments.remove(0) + } val parentIterator = firstParent[T].iterator(partition, context) - executePartition(context, partition.index, preparedArgument, parentIterator) + executePartition(context, partition.index, prepared, parentIterator) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 81f40ad33aa5d..b3c64394abc76 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -73,6 +73,16 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( super.clearDependencies() rdds = null } + + /** + * Call the prepare method of every parent that has one. + * This is needed for reserving execution memory in advance. + */ + protected def tryPrepareParents(): Unit = { + rdds.collect { + case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare() + } + } } private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( @@ -84,6 +94,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag] extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context)) } @@ -107,6 +118,7 @@ private[spark] class ZippedPartitionsRDD3 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), @@ -134,6 +146,7 @@ private[spark] class ZippedPartitionsRDD4 extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { + tryPrepareParents() val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context), diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala index c16930e7d6491..e281e817e493d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala @@ -46,11 +46,17 @@ class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSpark } // Verify that the numbers are pushed in the order expected - val result = { - new MapPartitionsWithPreparationRDD[Int, Int, Unit]( - parent, preparePartition, executePartition).collect() - } + val rdd = new MapPartitionsWithPreparationRDD[Int, Int, 
Unit]( + parent, preparePartition, executePartition) + val result = rdd.collect() assert(result === Array(10, 20, 30)) + + TestObject.things.clear() + // Zip two of these RDDs, both should be prepared before the parent is executed + val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit]( + parent, preparePartition, executePartition) + val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect() + assert(result2 === Array(10, 10, 20, 30, 20, 30)) } } From fe16fd0b8b717f01151bc659ec3299dab091c97a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 16:06:38 -0700 Subject: [PATCH 157/802] [SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to generate new label at binary reduction Currently OneVsRest use UDF to generate new binary label during training. Considering that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use ```when ... otherwise``` which will be more efficiency. Author: Yanbo Liang Closes #8519 from yanboliang/spark-10349. --- .../org/apache/spark/ml/classification/OneVsRest.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c62e132f5d533..debc164bf2432 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString val initUDF = udf { () => Map[Int, Double]() } - val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. @@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - val labelUDF = udf { (label: Double) => - if (label.toInt == index) 1.0 else 0.0 - } - // generate new label metadata for the binary problem. - // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val trainingDataset = - multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) + val trainingDataset = multiclassLabeled.withColumn( + labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) From 52ea399e6ee37b7c44aae7709863e006fca88906 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 31 Aug 2015 16:11:27 -0700 Subject: [PATCH 158/802] [SPARK-10355] [ML] [PySpark] Add Python API for SQLTransformer Add Python API for SQLTransformer Author: Yanbo Liang Closes #8527 from yanboliang/spark-10355. 
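For reference, the Scala `SQLTransformer` that this Python wrapper delegates to is driven the same way. A minimal usage sketch mirroring the doctest added below; it assumes an existing `SQLContext` is passed in by the surrounding application:

```
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.sql.SQLContext

object SQLTransformerSketch {
  def example(sqlContext: SQLContext): Unit = {
    val df = sqlContext.createDataFrame(Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
    // __THIS__ is substituted with the input DataFrame before the statement runs,
    // so the result adds v3 = v1 + v2 and v4 = v1 * v2 to every row.
    sqlTrans.transform(df).show()
  }
}
```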
--- python/pyspark/ml/feature.py | 57 ++++++++++++++++++++++++++++++++++-- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 59300a607815b..0626281e200a1 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -28,9 +28,9 @@ __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', - 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel', - 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', + 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] @inherit_doc @@ -743,6 +743,57 @@ def getPattern(self): return self.getOrDefault(self.pattern) +@inherit_doc +class SQLTransformer(JavaTransformer): + """ + Implements the transforms which are defined by SQL statement. + Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' + where '__THIS__' represents the underlying table of the input dataset. + + >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) + >>> sqlTrans = SQLTransformer( + ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") + >>> sqlTrans.transform(df).head() + Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0) + """ + + # a placeholder to make it appear in the generated doc + statement = Param(Params._dummy(), "statement", "SQL statement") + + @keyword_only + def __init__(self, statement=None): + """ + __init__(self, statement=None) + """ + super(SQLTransformer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) + self.statement = Param(self, "statement", "SQL statement") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, statement=None): + """ + setParams(self, statement=None) + Sets params for this SQLTransformer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setStatement(self, value): + """ + Sets the value of :py:attr:`statement`. + """ + self._paramMap[self.statement] = value + return self + + def getStatement(self): + """ + Gets the value of statement or its default value. + """ + return self.getOrDefault(self.statement) + + @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): """ From d65656c455d19b83c6412571873586b458aa355e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 31 Aug 2015 18:09:24 -0700 Subject: [PATCH 159/802] [SPARK-10378][SQL][Test] Remove HashJoinCompatibilitySuite. They don't bring much value since we now have better unit test coverage for hash joins. This will also help reduce the test time. Author: Reynold Xin Closes #8542 from rxin/SPARK-10378. 
--- .../HashJoinCompatibilitySuite.scala | 169 ------------------ 1 file changed, 169 deletions(-) delete mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala deleted file mode 100644 index 1a5ba20404c4e..0000000000000 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import java.io.File - -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.test.TestHive - -/** - * Runs the test cases that are included in the hive distribution with hash joins. - */ -class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { - override def beforeAll() { - super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) - } - - override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) - super.afterAll() - } - - override def whiteList = Seq( - "auto_join0", - "auto_join1", - "auto_join10", - "auto_join11", - "auto_join12", - "auto_join13", - "auto_join14", - "auto_join14_hadoop20", - "auto_join15", - "auto_join17", - "auto_join18", - "auto_join19", - "auto_join2", - "auto_join20", - "auto_join21", - "auto_join22", - "auto_join23", - "auto_join24", - "auto_join25", - "auto_join26", - "auto_join27", - "auto_join28", - "auto_join3", - "auto_join30", - "auto_join31", - "auto_join32", - "auto_join4", - "auto_join5", - "auto_join6", - "auto_join7", - "auto_join8", - "auto_join9", - "auto_join_filters", - "auto_join_nulls", - "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", - "correlationoptimizer1", - "correlationoptimizer10", - "correlationoptimizer11", - "correlationoptimizer13", - "correlationoptimizer14", - "correlationoptimizer15", - "correlationoptimizer2", - "correlationoptimizer3", - "correlationoptimizer4", - "correlationoptimizer6", - "correlationoptimizer7", - "correlationoptimizer8", - "correlationoptimizer9", - "join0", - "join1", - "join10", - "join11", - "join12", - "join13", - "join14", - 
"join14_hadoop20", - "join15", - "join16", - "join17", - "join18", - "join19", - "join2", - "join20", - "join21", - "join22", - "join23", - "join24", - "join25", - "join26", - "join27", - "join28", - "join29", - "join3", - "join30", - "join31", - "join32", - "join32_lessSize", - "join33", - "join34", - "join35", - "join36", - "join37", - "join38", - "join39", - "join4", - "join40", - "join41", - "join5", - "join6", - "join7", - "join8", - "join9", - "join_1to1", - "join_array", - "join_casesensitive", - "join_empty", - "join_filters", - "join_hive_626", - "join_map_ppr", - "join_nulls", - "join_nullsafe", - "join_rc", - "join_reorder2", - "join_reorder3", - "join_reorder4", - "join_star" - ) - - // Only run those query tests in the realWhileList (do not try other ignored query files). - override def testCases: Seq[(String, File)] = super.testCases.filter { - case (name, _) => realWhiteList.contains(name) - } -} From 391e6be0ae883f3ea0fab79463eb8b618af79afb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 1 Sep 2015 16:52:59 +0800 Subject: [PATCH 160/802] [SPARK-10301] [SQL] Fixes schema merging for nested structs This PR can be quite challenging to review. I'm trying to give a detailed description of the problem as well as its solution here. When reading Parquet files, we need to specify a potentially nested Parquet schema (of type `MessageType`) as requested schema for column pruning. This Parquet schema is translated from a Catalyst schema (of type `StructType`), which is generated by the query planner and represents all requested columns. However, this translation can be fairly complicated because of several reasons: 1. Requested schema must conform to the real schema of the physical file to be read. This means we have to tailor the actual file schema of every individual physical Parquet file to be read according to the given Catalyst schema. Fortunately we are already doing this in Spark 1.5 by pushing request schema conversion to executor side in PR #7231. 1. Support for schema merging. A single Parquet dataset may consist of multiple physical Parquet files come with different but compatible schemas. This means we may request for a column path that doesn't exist in a physical Parquet file. All requested column paths can be nested. For example, for a Parquet file schema ``` message root { required group f0 { required group f00 { required int32 f000; required binary f001 (UTF8); } } } ``` we may request for column paths defined in the following schema: ``` message root { required group f0 { required group f00 { required binary f001 (UTF8); required float f002; } } optional double f1; } ``` Notice that we pruned column path `f0.f00.f000`, but added `f0.f00.f002` and `f1`. The good news is that Parquet handles non-existing column paths properly and always returns null for them. 1. The map from `StructType` to `MessageType` is a one-to-many map. This is the most unfortunate part. Due to historical reasons (dark histories!), schemas of Parquet files generated by different libraries have different "flavors". 
For example, to handle a schema with a single non-nullable column, whose type is an array of non-nullable integers, parquet-protobuf generates the following Parquet schema: ``` message m0 { repeated int32 f; } ``` while parquet-avro generates another version: ``` message m1 { required group f (LIST) { repeated int32 array; } } ``` and parquet-thrift spills this: ``` message m1 { required group f (LIST) { repeated int32 f_tuple; } } ``` All of them can be mapped to the following _unique_ Catalyst schema: ``` StructType( StructField( "f", ArrayType(IntegerType, containsNull = false), nullable = false)) ``` This greatly complicates Parquet requested schema construction, since the path of a given column varies in different cases. To read the array elements from files with the above schemas, we must use `f` for `m0`, `f.array` for `m1`, and `f.f_tuple` for `m2`. In earlier Spark versions, we didn't try to fix this issue properly. Spark 1.4 and prior versions simply translate the Catalyst schema in a way more or less compatible with parquet-hive and parquet-avro, but is broken in many other cases. Earlier revisions of Spark 1.5 only try to tailor the Parquet file schema at the first level, and ignore nested ones. This caused [SPARK-10301] [spark-10301] as well as [SPARK-10005] [spark-10005]. In PR #8228, I tried to avoid the hard part of the problem and made a minimum change in `CatalystRowConverter` to fix SPARK-10005. However, when taking SPARK-10301 into consideration, keeping hacking `CatalystRowConverter` doesn't seem to be a good idea. So this PR is an attempt to fix the problem in a proper way. For a given physical Parquet file with schema `ps` and a compatible Catalyst requested schema `cs`, we use the following algorithm to tailor `ps` to get the result Parquet requested schema `ps'`: For a leaf column path `c` in `cs`: - if `c` exists in `cs` and a corresponding Parquet column path `c'` can be found in `ps`, `c'` should be included in `ps'`; - otherwise, we convert `c` to a Parquet column path `c"` using `CatalystSchemaConverter`, and include `c"` in `ps'`; - no other column paths should exist in `ps'`. Then comes the most tedious part: > Given `cs`, `ps`, and `c`, how to locate `c'` in `ps`? Unfortunately, there's no quick answer, and we have to enumerate all possible structures defined in parquet-format spec. They are: 1. the standard structure of nested types, and 1. cases defined in all backwards-compatibility rules for `LIST` and `MAP`. The core part of this PR is `CatalystReadSupport.clipParquetType()`, which tailors a given Parquet file schema according to a requested schema in its Catalyst form. Backwards-compatibility rules of `LIST` and `MAP` are covered in `clipParquetListType()` and `clipParquetMapType()` respectively. The column path selection algorithm is implemented in `clipParquetGroupFields()`. With this PR, we no longer need to do schema tailoring in `CatalystReadSupport` and `CatalystRowConverter`. Another benefit is that, now we can also read Parquet datasets consist of files with different physical Parquet schema but share the same logical schema, for example, files generated by different Parquet libraries. This situation is illustrated by [this test case] [test-case]. 
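The tailoring step described above can be exercised directly through the `CatalystReadSupport.clipParquetSchema` method this patch introduces (the new `testSchemaClipping` helper in `ParquetSchemaSuite` does exactly this). A minimal sketch using the example schemas from this description; because `clipParquetSchema` is `private[parquet]`, the sketch assumes it is compiled inside that package:

```
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types._

object SchemaClippingSketch {
  def main(args: Array[String]): Unit = {
    // Physical schema of the Parquet file being read.
    val fileSchema = MessageTypeParser.parseMessageType(
      """message root {
        |  required group f0 {
        |    required group f00 {
        |      required int32 f000;
        |      required binary f001 (UTF8);
        |    }
        |  }
        |}
      """.stripMargin)

    // Catalyst requested schema: f0.f00.f001 exists in the file,
    // while f0.f00.f002 and f1 do not.
    val f00Type = new StructType()
      .add("f001", StringType, nullable = false)
      .add("f002", FloatType, nullable = false)
    val f0Type = new StructType().add("f00", f00Type, nullable = false)
    val requestedSchema = new StructType()
      .add("f0", f0Type, nullable = false)
      .add("f1", DoubleType, nullable = true)

    val clipped = CatalystReadSupport.clipParquetSchema(fileSchema, requestedSchema)
    // f000 is pruned; f001 keeps the file's physical type; f002 and f1 are
    // synthesized from the Catalyst schema and will read as nulls.
    println(clipped)
  }
}
```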
[spark-10301]: https://issues.apache.org/jira/browse/SPARK-10301 [spark-10005]: https://issues.apache.org/jira/browse/SPARK-10005 [test-case]: https://github.com/liancheng/spark/commit/38644d8a45175cbdf20d2ace021c2c2544a50ab3#diff-a9b98e28ce3ae30641829dffd1173be2R26 Author: Cheng Lian Closes #8509 from liancheng/spark-10301/fix-parquet-requested-schema. --- .../parquet/CatalystReadSupport.scala | 235 +++++++++---- .../parquet/CatalystRowConverter.scala | 51 +-- .../parquet/CatalystSchemaConverter.scala | 14 +- .../ParquetAvroCompatibilitySuite.scala | 1 + .../ParquetInteroperabilitySuite.scala | 90 +++++ .../parquet/ParquetQuerySuite.scala | 77 +++++ .../parquet/ParquetSchemaSuite.scala | 310 ++++++++++++++++++ 7 files changed, 653 insertions(+), 125 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 0a6bb44445f6e..dc4ff06df6f22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -19,17 +19,18 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.{Map => JMap} -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, mapAsScalaMapConverter} import org.apache.hadoop.conf.Configuration import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.MessageType +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema._ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { // Called after `init()` when initializing Parquet record reader. @@ -81,70 +82,10 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // `StructType` containing all requested columns. val maybeRequestedSchema = Option(conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - // Below we construct a Parquet schema containing all requested columns. This schema tells - // Parquet which columns to read. - // - // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, - // we have to fallback to the full file schema which contains all columns in the file. - // Obviously this may waste IO bandwidth since it may read more columns than requested. - // - // Two things to note: - // - // 1. It's possible that some requested columns don't exist in the target Parquet file. For - // example, in the case of schema merging, the globally merged schema may contain extra - // columns gathered from other Parquet files. These columns will be simply filled with nulls - // when actually reading the target Parquet file. - // - // 2. 
When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to - // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to - // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file - // containing a single integer array field `f1` may have the following legacy 2-level - // structure: - // - // message root { - // optional group f1 (LIST) { - // required INT32 element; - // } - // } - // - // while `CatalystSchemaConverter` may generate a standard 3-level structure: - // - // message root { - // optional group f1 (LIST) { - // repeated group list { - // required INT32 element; - // } - // } - // } - // - // Apparently, we can't use the 2nd schema to read the target Parquet file as they have - // different physical structures. val parquetRequestedSchema = maybeRequestedSchema.fold(context.getFileSchema) { schemaString => - val toParquet = new CatalystSchemaConverter(conf) - val fileSchema = context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.asScala.map(_.getName).toSet - - StructType - // Deserializes the Catalyst schema of requested columns - .fromString(schemaString) - .map { field => - if (fileFieldNames.contains(field.name)) { - // If the field exists in the target Parquet file, extracts the field type from the - // full file schema and makes a single-field Parquet schema - new MessageType("root", fileSchema.getType(field.name)) - } else { - // Otherwise, just resorts to `CatalystSchemaConverter` - toParquet.convert(StructType(Array(field))) - } - } - // Merges all single-field Parquet schemas to form a complete schema for all requested - // columns. Note that it's possible that no columns are requested at all (e.g., count - // some partition column of a partitioned Parquet table). That's why `fold` is used here - // and always fallback to an empty Parquet schema. - .fold(new MessageType("root")) { - _ union _ - } + val catalystRequestedSchema = StructType.fromString(schemaString) + CatalystReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) } val metadata = @@ -160,4 +101,168 @@ private[parquet] object CatalystReadSupport { val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" + + /** + * Tailors `parquetSchema` according to `catalystSchema` by removing column paths don't exist + * in `catalystSchema`, and adding those only exist in `catalystSchema`. + */ + def clipParquetSchema(parquetSchema: MessageType, catalystSchema: StructType): MessageType = { + val clippedParquetFields = clipParquetGroupFields(parquetSchema.asGroupType(), catalystSchema) + Types.buildMessage().addFields(clippedParquetFields: _*).named("root") + } + + private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { + catalystType match { + case t: ArrayType if !isPrimitiveCatalystType(t.elementType) => + // Only clips array types with nested type as element type. + clipParquetListType(parquetType.asGroupType(), t.elementType) + + case t: MapType if !isPrimitiveCatalystType(t.valueType) => + // Only clips map types with nested type as value type. + clipParquetMapType(parquetType.asGroupType(), t.keyType, t.valueType) + + case t: StructType => + clipParquetGroup(parquetType.asGroupType(), t) + + case _ => + parquetType + } + } + + /** + * Whether a Catalyst [[DataType]] is primitive. Primitive [[DataType]] is not equivalent to + * [[AtomicType]]. 
For example, [[CalendarIntervalType]] is primitive, but it's not an + * [[AtomicType]]. + */ + private def isPrimitiveCatalystType(dataType: DataType): Boolean = { + dataType match { + case _: ArrayType | _: MapType | _: StructType => false + case _ => true + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[ArrayType]]. The element type + * of the [[ArrayType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a + * [[StructType]]. + */ + private def clipParquetListType(parquetList: GroupType, elementType: DataType): Type = { + // Precondition of this method, should only be called for lists with nested element types. + assert(!isPrimitiveCatalystType(elementType)) + + // Unannotated repeated group should be interpreted as required list of required element, so + // list element type is just the group itself. Clip it. + if (parquetList.getOriginalType == null && parquetList.isRepetition(Repetition.REPEATED)) { + clipParquetType(parquetList, elementType) + } else { + assert( + parquetList.getOriginalType == OriginalType.LIST, + "Invalid Parquet schema. " + + "Original type of annotated Parquet lists must be LIST: " + + parquetList.toString) + + assert( + parquetList.getFieldCount == 1 && parquetList.getType(0).isRepetition(Repetition.REPEATED), + "Invalid Parquet schema. " + + "LIST-annotated group should only have exactly one repeated field: " + + parquetList) + + // Precondition of this method, should only be called for lists with nested element types. + assert(!parquetList.getType(0).isPrimitive) + + val repeatedGroup = parquetList.getType(0).asGroupType() + + // If the repeated field is a group with multiple fields, or the repeated field is a group + // with one field and is named either "array" or uses the LIST-annotated group's name with + // "_tuple" appended then the repeated type is the element type and elements are required. + // Build a new LIST-annotated group with clipped `repeatedGroup` as element type and the + // only field. + if ( + repeatedGroup.getFieldCount > 1 || + repeatedGroup.getName == "array" || + repeatedGroup.getName == parquetList.getName + "_tuple" + ) { + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField(clipParquetType(repeatedGroup, elementType)) + .named(parquetList.getName) + } else { + // Otherwise, the repeated field's type is the element type with the repeated field's + // repetition. + Types + .buildGroup(parquetList.getRepetition) + .as(OriginalType.LIST) + .addField( + Types + .repeatedGroup() + .addField(clipParquetType(repeatedGroup.getType(0), elementType)) + .named(repeatedGroup.getName)) + .named(parquetList.getName) + } + } + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[MapType]]. The value type + * of the [[MapType]] should also be a nested type, namely an [[ArrayType]], a [[MapType]], or a + * [[StructType]]. Note that key type of any [[MapType]] is always a primitive type. + */ + private def clipParquetMapType( + parquetMap: GroupType, keyType: DataType, valueType: DataType): GroupType = { + // Precondition of this method, should only be called for maps with nested value types. 
+ assert(!isPrimitiveCatalystType(valueType)) + + val repeatedGroup = parquetMap.getType(0).asGroupType() + val parquetKeyType = repeatedGroup.getType(0) + val parquetValueType = repeatedGroup.getType(1) + + val clippedRepeatedGroup = + Types + .repeatedGroup() + .as(repeatedGroup.getOriginalType) + .addField(parquetKeyType) + .addField(clipParquetType(parquetValueType, valueType)) + .named(repeatedGroup.getName) + + Types + .buildGroup(parquetMap.getRepetition) + .as(parquetMap.getOriginalType) + .addField(clippedRepeatedGroup) + .named(parquetMap.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A clipped [[GroupType]], which has at least one field. + * @note Parquet doesn't allow creating empty [[GroupType]] instances except for empty + * [[MessageType]]. Because it's legal to construct an empty requested schema for column + * pruning. + */ + private def clipParquetGroup(parquetRecord: GroupType, structType: StructType): GroupType = { + val clippedParquetFields = clipParquetGroupFields(parquetRecord, structType) + Types + .buildGroup(parquetRecord.getRepetition) + .as(parquetRecord.getOriginalType) + .addFields(clippedParquetFields: _*) + .named(parquetRecord.getName) + } + + /** + * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. + * + * @return A list of clipped [[GroupType]] fields, which can be empty. + */ + private def clipParquetGroupFields( + parquetRecord: GroupType, structType: StructType): Seq[Type] = { + val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap + val toParquet = new CatalystSchemaConverter(followParquetFormatSpec = true) + structType.map { f => + parquetFieldMap + .get(f.name) + .map(clipParquetType(_, f.dataType)) + .getOrElse(toParquet.convertField(f)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index fe13dfbbed385..f17e794b76650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -113,31 +113,6 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. * - * @note Constructor argument [[parquetType]] refers to requested fields of the actual schema of the - * Parquet file being read, while constructor argument [[catalystType]] refers to requested - * fields of the global schema. The key difference is that, in case of schema merging, - * [[parquetType]] can be a subset of [[catalystType]]. 
For example, it's possible to have - * the following [[catalystType]]: - * {{{ - * new StructType() - * .add("f1", IntegerType, nullable = false) - * .add("f2", StringType, nullable = true) - * .add("f3", new StructType() - * .add("f31", DoubleType, nullable = false) - * .add("f32", IntegerType, nullable = true) - * .add("f33", StringType, nullable = true), nullable = false) - * }}} - * and the following [[parquetType]] (`f2` and `f32` are missing): - * {{{ - * message root { - * required int32 f1; - * required group f3 { - * required double f31; - * optional binary f33 (utf8); - * } - * } - * }}} - * * @param parquetType Parquet schema of Parquet records * @param catalystType Spark SQL schema that corresponds to the Parquet record type * @param updater An updater which propagates converted field values to the parent container @@ -179,31 +154,7 @@ private[parquet] class CatalystRowConverter( // Converters for each field. private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { - // In case of schema merging, `parquetType` can be a subset of `catalystType`. We need to pad - // those missing fields and create converters for them, although values of these fields are - // always null. - val paddedParquetFields = { - val parquetFields = parquetType.getFields.asScala - val parquetFieldNames = parquetFields.map(_.getName).toSet - val missingFields = catalystType.filterNot(f => parquetFieldNames.contains(f.name)) - - // We don't need to worry about feature flag arguments like `assumeBinaryIsString` when - // creating the schema converter here, since values of missing fields are always null. - val toParquet = new CatalystSchemaConverter() - - (parquetFields ++ missingFields.map(toParquet.convertField)).sortBy { f => - catalystType.indexWhere(_.name == f.getName) - } - } - - if (paddedParquetFields.length != catalystType.length) { - throw new UnsupportedOperationException( - "A Parquet file's schema has different number of fields with the table schema. " + - "Please enable schema merging by setting \"mergeSchema\" to true when load " + - "a Parquet dataset or set spark.sql.parquet.mergeSchema to true in SQLConf.") - } - - paddedParquetFields.zip(catalystType).zipWithIndex.map { + parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map { case ((parquetFieldType, catalystField), ordinal) => // Converted field value should be set to the `ordinal`-th cell of `currentRow` newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index be6c0545f5a0a..a21ab1dbb25d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -55,16 +55,10 @@ import org.apache.spark.sql.{AnalysisException, SQLConf} * to old style non-standard behaviors. */ private[parquet] class CatalystSchemaConverter( - private val assumeBinaryIsString: Boolean, - private val assumeInt96IsTimestamp: Boolean, - private val followParquetFormatSpec: Boolean) { - - // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in - // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. 
- def this() = this( - assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, - assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, - followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec: Boolean = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get +) { def this(conf: SQLConf) = this( assumeBinaryIsString = conf.isParquetBinaryAsString, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index bd7cf8c10abef..36b929ee1f409 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet +import java.io.File import java.nio.ByteBuffer import java.util.{List => JList, Map => JMap} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala new file mode 100644 index 0000000000000..83b65fb419ed3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext { + test("parquet files with different physical schemas but share the same logical schema") { + import ParquetCompatibilityTest._ + + // This test case writes two Parquet files, both representing the following Catalyst schema + // + // StructType( + // StructField( + // "f", + // ArrayType(IntegerType, containsNull = false), + // nullable = false)) + // + // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the + // other one comes with parquet-protobuf style 1-level unannotated primitive field. 
+ withTempDir { dir => + val avroStylePath = new File(dir, "avro-style").getCanonicalPath + val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath + + val avroStyleSchema = + """message avro_style { + | required group f (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin + + writeDirect(avroStylePath, avroStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.group { + rc.field("array", 0) { + rc.addInteger(0) + rc.addInteger(1) + } + } + } + } + }) + + logParquetSchema(avroStylePath) + + val protobufStyleSchema = + """message protobuf_style { + | repeated int32 f; + |} + """.stripMargin + + writeDirect(protobufStylePath, protobufStyleSchema, { rc => + rc.message { + rc.field("f", 0) { + rc.addInteger(2) + rc.addInteger(3) + } + } + }) + + logParquetSchema(protobufStylePath) + + checkAnswer( + sqlContext.read.parquet(dir.getCanonicalPath), + Seq( + Row(Seq(0, 1)), + Row(Seq(2, 3)))) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b7b70c2bbbd5c..a379523d67f80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -229,4 +229,81 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-10301 Clipping nested structs in requested schema") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .coalesce(1) + + df.write.mode("append").parquet(path) + + val userDefinedSchema = new StructType() + .add("s", new StructType().add("a", LongType, nullable = true), nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(0))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + val df1 = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', id, 'b', id) AS s") + .coalesce(1) + + val df2 = sqlContext + .range(1, 2) + .selectExpr("NAMED_STRUCT('b', id, 'c', id) AS s") + .coalesce(1) + + df1.write.parquet(path) + df2.write.mode(SaveMode.Append).parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add("a", LongType, nullable = true) + .add("c", LongType, nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Seq( + Row(Row(0, null)), + Row(Row(null, 1)))) + } + + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext + .range(1) + .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") + .coalesce(1) + + df.write.parquet(path) + + val userDefinedSchema = new StructType() + .add("s", + new StructType() + .add( + "a", + ArrayType( + new StructType() + .add("b", LongType, nullable = true) + .add("d", StringType, nullable = true), + containsNull = true), + nullable = true), + nullable = true) + + checkAnswer( + sqlContext.read.schema(userDefinedSchema).parquet(path), + Row(Row(Seq(Row(0, null))))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 9dcbc1a047bea..28c59a4abdd76 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.parquet.schema.MessageTypeParser +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -941,4 +942,313 @@ class ParquetSchemaSuite extends ParquetSchemaTest { | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); |} """.stripMargin) + + private def testSchemaClipping( + testName: String, + parquetSchema: String, + catalystSchema: StructType, + expectedSchema: String): Unit = { + test(s"Clipping - $testName") { + val expected = MessageTypeParser.parseMessageType(expectedSchema) + val actual = CatalystReadSupport.clipParquetSchema( + MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) + + try { + expected.checkContains(actual) + actual.checkContains(expected) + } catch { case cause: Throwable => + fail( + s"""Expected clipped schema: + |$expected + |Actual clipped schema: + |$actual + """.stripMargin, + cause) + } + } + } + + testSchemaClipping( + "simple nested struct", + + parquetSchema = + """message root { + | required group f0 { + | optional int32 f00; + | optional int32 f01; + | } + |} + """.stripMargin, + + catalystSchema = { + val f0Type = new StructType().add("f00", IntegerType, nullable = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", IntegerType, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional int32 f00; + | } + | optional int32 f1; + |} + """.stripMargin) + + testSchemaClipping( + "parquet-protobuf style array", + + parquetSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional int32 f010; + | optional double f011; + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11Type = new StructType().add("f011", DoubleType, nullable = true) + val f01Type = ArrayType(StringType, containsNull = false) + val f0Type = new StructType() + .add("f00", f01Type, nullable = false) + .add("f01", f11Type, nullable = false) + val f1Type = ArrayType(IntegerType, containsNull = true) + new StructType() + .add("f0", f0Type, nullable = false) + .add("f1", f1Type, nullable = true) + }, + + expectedSchema = + """message root { + | required group f0 { + | repeated binary f00 (UTF8); + | repeated group f01 { + | optional double f011; + | } + | } + | + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-thrift style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = false) + .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = 
+ """message root { + | required group f0 { + | optional group f00 { + | repeated binary f00_tuple (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group f01_tuple { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-avro style array", + + parquetSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional int32 f010; + | optional double f011; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f11ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = false), nullable = false) + .add("f01", ArrayType(f11ElementType, containsNull = false), nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + """message root { + | required group f0 { + | optional group f00 { + | repeated binary array (UTF8); + | } + | + | optional group f01 (LIST) { + | repeated group array { + | optional double f011; + | optional int64 f012; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "parquet-hive style array", + + parquetSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional int32 f010; + | optional double f011; + | } + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f01ElementType = new StructType() + .add("f011", DoubleType, nullable = true) + .add("f012", LongType, nullable = true) + + val f0Type = new StructType() + .add("f00", ArrayType(StringType, containsNull = true), nullable = true) + .add("f01", ArrayType(f01ElementType, containsNull = true), nullable = true) + + new StructType().add("f0", f0Type, nullable = true) + }, + + expectedSchema = + """message root { + | optional group f0 { + | optional group f00 (LIST) { + | repeated group bag { + | optional binary array_element; + | } + | } + | + | optional group f01 (LIST) { + | repeated group bag { + | optional group array_element { + | optional double f011; + | optional int64 f012; + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "2-level list of required struct", + + parquetSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | required int32 f000; + | optional int64 f001; + | } + | } + | } + |} + """.stripMargin, + + catalystSchema = { + val f00ElementType = + new StructType() + .add("f001", LongType, nullable = true) + .add("f002", DoubleType, nullable = false) + + val f00Type = ArrayType(f00ElementType, containsNull = false) + val f0Type = new StructType().add("f00", f00Type, nullable = false) + + new StructType().add("f0", f0Type, nullable = false) + }, + + expectedSchema = + s"""message root { + | required group f0 { + | required group f00 (LIST) { + | repeated group element { + | optional int64 f001; + | required double f002; + | } + | } + | } + |} + """.stripMargin) + + testSchemaClipping( + "empty requested schema", + + parquetSchema = + """message root { + | required group f0 { + | required int32 f00; + | required int64 f01; + | } + |} + """.stripMargin, + + catalystSchema = new 
StructType(), + + expectedSchema = "message root {}") } From e6e483cc4de740c46398385b03ffe0e662edae39 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 1 Sep 2015 10:48:57 -0700 Subject: [PATCH 161/802] [SPARK-9679] [ML] [PYSPARK] Add Python API for Stop Words Remover Add a python API for the Stop Words Remover. Author: Holden Karau Closes #8118 from holdenk/SPARK-9679-python-StopWordsRemover. --- .../spark/ml/feature/StopWordsRemover.scala | 6 +- .../ml/feature/StopWordsRemoverSuite.scala | 2 +- python/pyspark/ml/feature.py | 73 ++++++++++++++++++- python/pyspark/ml/tests.py | 20 ++++- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5d77ea08db657..7da430c7d16df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructTyp /** * stop words list */ -private object StopWords { +private[spark] object StopWords { /** * Use the same default stopwords list as scikit-learn. * The original list can be found from "Glasgow Information Retrieval Group" * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] */ - val EnglishStopWords = Array( "a", "about", "above", "across", "after", "afterwards", "again", + val English = Array( "a", "about", "above", "across", "after", "afterwards", "again", "against", "all", "almost", "alone", "along", "already", "also", "although", "always", "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", @@ -121,7 +121,7 @@ class StopWordsRemover(override val uid: String) /** @group getParam */ def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.EnglishStopWords, caseSensitive -> false) + setDefault(stopWords -> StopWords.English, caseSensitive -> false) override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index f01306f89cb5f..e0d433f566c25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -65,7 +65,7 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { } test("StopWordsRemover with additional words") { - val stopWords = StopWords.EnglishStopWords ++ Array("python", "scala") + val stopWords = StopWords.English ++ Array("python", "scala") val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0626281e200a1..d955307e27efd 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -30,7 +30,7 @@ 
'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel'] + 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] @inherit_doc @@ -933,6 +933,75 @@ class StringIndexerModel(JavaModel): """ +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A feature transformer that filters out stop words from input. + Note: null values from input array are preserved unless adding null to stopWords explicitly. + """ + # a placeholder to make the stopwords show up in generated doc + stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out") + caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + + "comparison over the stop words") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + __init__(self, inputCol=None, outputCol=None, stopWords=None,\ + caseSensitive=false) + """ + super(StopWordsRemover, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", + self.uid) + self.stopWords = Param(self, "stopWords", "The words to be filtered out") + self.caseSensitive = Param(self, "caseSensitive", "whether to do a case " + + "sensitive comparison over the stop words") + stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords + defaultStopWords = stopWordsObj.English() + self._setDefault(stopWords=defaultStopWords) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False): + """ + setParams(self, inputCol="input", outputCol="output", stopWords=None,\ + caseSensitive=false) + Sets params for this StopWordRemover. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setStopWords(self, value): + """ + Specify the stopwords to be filtered. + """ + self._paramMap[self.stopWords] = value + return self + + def getStopWords(self): + """ + Get the stopwords. + """ + return self.getOrDefault(self.stopWords) + + def setCaseSensitive(self, value): + """ + Set whether to do a case sensitive comparison over the stop words + """ + self._paramMap[self.caseSensitive] = value + return self + + def getCaseSensitive(self): + """ + Get whether to do a case sensitive comparison over the stop words. 
+ """ + return self.getOrDefault(self.caseSensitive) + + @inherit_doc @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 60e4237293adc..b892318f50bd9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,7 +31,7 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame, SQLContext +from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params @@ -258,7 +258,7 @@ def test_idf(self): def test_ngram(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([ - ([["a", "b", "c", "d", "e"]])], ["input"]) + Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) self.assertEqual(ngram0.getInputCol(), "input") @@ -266,6 +266,22 @@ def test_ngram(self): transformedDF = ngram0.transform(dataset) self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + def test_stopwordsremover(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") + # Default + self.assertEquals(stopWordRemover.getInputCol(), "input") + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["panda"]) + # Custom + stopwords = ["panda"] + stopWordRemover.setStopWords(stopwords) + self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEquals(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a"]) + class HasInducedError(Params): From 3f63bd6023edcc9af268933a235f34e10bc3d2ba Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 1 Sep 2015 20:06:01 +0100 Subject: [PATCH 162/802] [SPARK-10398] [DOCS] Migrate Spark download page to use new lua mirroring scripts Migrate Apache download closer.cgi refs to new closer.lua This is the bit of the change that affects the project docs; I'm implementing the changes to the Apache site separately. Author: Sean Owen Closes #8557 from srowen/SPARK-10398. --- docker/spark-mesos/Dockerfile | 2 +- docs/running-on-mesos.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile index b90aef3655dee..fb3f267fe5c78 100644 --- a/docker/spark-mesos/Dockerfile +++ b/docker/spark-mesos/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && \ apt-get install -y python libnss3 openjdk-7-jre-headless curl RUN mkdir /opt/spark && \ - curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + curl http://www.apache.org/dyn/closer.lua/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ | tar -xzC /opt ENV SPARK_HOME /opt/spark ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index cfd219ab02e26..f36921ae30c2f 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -45,7 +45,7 @@ frameworks. You can install Mesos either from source or using prebuilt packages To install Apache Mesos from source, follow these steps: 1. 
Download a Mesos release from a - [mirror](http://www.apache.org/dyn/closer.cgi/mesos/{{site.MESOS_VERSION}}/) + [mirror](http://www.apache.org/dyn/closer.lua/mesos/{{site.MESOS_VERSION}}/) 2. Follow the Mesos [Getting Started](http://mesos.apache.org/gettingstarted) page for compiling and installing Mesos From ec012805337926e56343be2761a1037296446880 Mon Sep 17 00:00:00 2001 From: zhuol Date: Tue, 1 Sep 2015 11:14:59 -1000 Subject: [PATCH 163/802] [SPARK-4223] [CORE] Support * in acls. SPARK-4223. Currently we support setting view and modify acls but you have to specify a list of users. It would be nice to support * meaning all users have access. Manual tests to verify that: "*" works for any user in: a. Spark ui: view and kill stage. Done. b. Spark history server. Done. c. Yarn application killing. Done. Author: zhuol Closes #8398 from zhuoliu/4223. --- .../org/apache/spark/SecurityManager.scala | 26 ++++++++++-- .../apache/spark/SecurityManagerSuite.scala | 41 +++++++++++++++++++ docs/configuration.md | 9 ++-- 3 files changed, 69 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 673ef49e7c1c5..746d2081d4393 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -310,7 +310,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } - def getViewAcls: String = viewAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getViewAcls: String = { + if (viewAcls.contains("*")) { + "*" + } else { + viewAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. If you modify the admin @@ -321,7 +330,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } - def getModifyAcls: String = modifyAcls.mkString(",") + /** + * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" + */ + def getModifyAcls: String = { + if (modifyAcls.contains("*")) { + "*" + } else { + modifyAcls.mkString(",") + } + } /** * Admin acls should be set before the view or modify acls. 
If you modify the admin @@ -394,7 +412,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) + !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") } /** @@ -409,7 +427,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) + !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index f34aefca4eb18..f29160d834082 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -125,6 +125,47 @@ class SecurityManagerSuite extends SparkFunSuite { } + test("set security with * in acls") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls", "user1,user2") + conf.set("spark.ui.view.acls", "*") + conf.set("spark.modify.acls", "user4") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAcls with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user4") === true) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for modifyAcls with * + securityManager.setModifyAcls(Set("user4"), "*") + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + + securityManager.setAdminAcls("user1,user2") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === false) + assert(securityManager.checkUIViewPermissions("user6") === false) + assert(securityManager.checkModifyPermissions("user7") === false) + assert(securityManager.checkModifyPermissions("user8") === false) + + // check for adminAcls with * + securityManager.setAdminAcls("user1,*") + securityManager.setModifyAcls(Set("user1"), "user2") + securityManager.setViewAcls(Set("user1"), "user2") + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() val expectedAlgorithms = Set( diff --git a/docs/configuration.md b/docs/configuration.md index 77c5cbc7b3196..fb0315ce7c3cc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1286,7 +1286,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. 
This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. + help debug when things work. Putting a "*" in the list means any user can have the priviledge + of admin. @@ -1327,7 +1328,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have modify access to the Spark job. By default only the - user that started the Spark job has access to modify it (kill it for example). + user that started the Spark job has access to modify it (kill it for example). Putting a "*" in + the list means any user can have access to modify it. @@ -1349,7 +1351,8 @@ Apart from these, the following properties are also available, and may be useful Empty Comma separated list of users that have view access to the Spark web ui. By default only the - user that started the Spark job has view access. + user that started the Spark job has view access. Putting a "*" in the list means any user can + have view access to this Spark job. From bf550a4b551b6dd18fea3eb3f70497f9a6ad8e6c Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Tue, 1 Sep 2015 14:34:59 -0700 Subject: [PATCH 164/802] [SPARK-10162] [SQL] Fix the timezone omitting for PySpark Dataframe filter function This PR addresses [SPARK-10162](https://issues.apache.org/jira/browse/SPARK-10162) The issue is with DataFrame filter() function, if datetime.datetime is passed to it: * Timezone information of this datetime is ignored * This datetime is assumed to be in local timezone, which depends on the OS timezone setting Fix includes both code change and regression test. Problem reproduction code on master: ```python import pytz from datetime import datetime from pyspark.sql import * from pyspark.sql.types import * sqc = SQLContext(sc) df = sqc.createDataFrame([], StructType([StructField("dt", TimestampType())])) m1 = pytz.timezone('UTC') m2 = pytz.timezone('Etc/GMT+3') df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() ``` It gives the same timestamp ignoring time zone: ``` >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() Filter (dt#0 > 946713600000000) Scan PhysicalRDD[dt#0] >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() Filter (dt#0 > 946713600000000) Scan PhysicalRDD[dt#0] ``` After the fix: ``` >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m1)).explain() Filter (dt#0 > 946684800000000) Scan PhysicalRDD[dt#0] >>> df.filter(df.dt > datetime(2000, 01, 01, tzinfo=m2)).explain() Filter (dt#0 > 946695600000000) Scan PhysicalRDD[dt#0] ``` PR [8536](https://github.com/apache/spark/pull/8536) was occasionally closed by me dropping the repo Author: 0x0FFF Closes #8555 from 0x0FFF/SPARK-10162. 
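For reference, the behavioral difference comes down to how a timezone-aware `datetime` is converted to epoch seconds. A minimal standalone sketch of the two conversions (plain Python 3, no Spark required; variable names are illustrative):

```python
import calendar
import time
from datetime import datetime, timedelta, timezone

# A timezone-aware timestamp: 2000-01-01 00:00:00 at UTC-3, i.e. 03:00:00 UTC.
dt = datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=-3)))

# Old behaviour: tzinfo is ignored and the naive timetuple is interpreted in
# the local OS timezone, so the result depends on the machine's settings.
local_seconds = time.mktime(dt.timetuple())

# Fixed behaviour: utctimetuple() normalizes to UTC first, then timegm()
# converts without consulting the local timezone.
utc_seconds = calendar.timegm(dt.utctimetuple())

print(local_seconds)  # machine-dependent
print(utc_seconds)    # 946695600, i.e. the 946695600000000 us shown above for Etc/GMT+3
```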
--- python/pyspark/sql/tests.py | 26 ++++++++++++++++++-------- python/pyspark/sql/types.py | 7 +++++-- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd32e26c64f22..59a891bd7c420 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -50,16 +50,17 @@ from pyspark.sql.utils import AnalysisException, IllegalArgumentException -class UTC(datetime.tzinfo): - """UTC""" - ZERO = datetime.timedelta(0) +class UTCOffsetTimezone(datetime.tzinfo): + """ + Specifies timezone in UTC offset + """ + + def __init__(self, offset=0): + self.ZERO = datetime.timedelta(hours=offset) def utcoffset(self, dt): return self.ZERO - def tzname(self, dt): - return "UTC" - def dst(self, dt): return self.ZERO @@ -841,13 +842,22 @@ def test_filter_with_datetime(self): self.assertEqual(0, df.filter(df.date > date).count()) self.assertEqual(0, df.filter(df.time > time).count()) + def test_filter_with_datetime_timezone(self): + dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) + dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) + row = Row(date=dt1) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(0, df.filter(df.date == dt2).count()) + self.assertEqual(1, df.filter(df.date > dt2).count()) + self.assertEqual(0, df.filter(df.date < dt2).count()) + def test_time_with_timezone(self): day = datetime.date.today() now = datetime.datetime.now() ts = time.mktime(now.timetuple()) # class in __main__ is not serializable - from pyspark.sql.tests import UTC - utc = UTC() + from pyspark.sql.tests import UTCOffsetTimezone + utc = UTCOffsetTimezone() utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds # add microseconds to utcnow (keeping year,month,day,hour,minute,second) utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 94e581a78364c..f84d08d7098ad 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1290,8 +1290,11 @@ def can_convert(self, obj): def convert(self, obj, gateway_client): Timestamp = JavaClass("java.sql.Timestamp", gateway_client) - return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) - + seconds = (calendar.timegm(obj.utctimetuple()) if obj.tzinfo + else time.mktime(obj.timetuple())) + t = Timestamp(int(seconds) * 1000) + t.setNanos(obj.microsecond * 1000) + return t # datetime is a subclass of date, we should register DatetimeConverter first register_input_converter(DatetimeConverter()) From 00d9af5e190475affffb8b50467fcddfc40f50dc Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Tue, 1 Sep 2015 14:58:49 -0700 Subject: [PATCH 165/802] [SPARK-10392] [SQL] Pyspark - Wrong DateType support on JDBC connection This PR addresses issue [SPARK-10392](https://issues.apache.org/jira/browse/SPARK-10392) The problem is that for "start of epoch" date (01 Jan 1970) PySpark class DateType returns 0 instead of the `datetime.date` due to implementation of its return statement Issue reproduction on master: ``` >>> from pyspark.sql.types import * >>> a = DateType() >>> a.fromInternal(0) 0 >>> a.fromInternal(1) datetime.date(1970, 1, 2) ``` Author: 0x0FFF Closes #8556 from 0x0FFF/SPARK-10392. 
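The root cause is the `x and f(x)` idiom used as a None guard: for the start of the epoch the internal value is the integer `0`, which is falsy, so the guard short-circuits and returns `0` itself instead of a date. A standalone sketch mirroring the `types.py` hunk below (function names shortened for illustration):

```python
import datetime

EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

def from_internal_buggy(v):
    # `v and ...` was intended as a None guard, but it also short-circuits
    # when v == 0 (the epoch day) and returns the integer 0 instead of a date.
    return v and datetime.date.fromordinal(v + EPOCH_ORDINAL)

def from_internal_fixed(v):
    # Only skip the conversion when v is actually None.
    if v is not None:
        return datetime.date.fromordinal(v + EPOCH_ORDINAL)

print(from_internal_buggy(0))  # 0 -- wrong
print(from_internal_fixed(0))  # 1970-01-01
print(from_internal_fixed(1))  # 1970-01-02
```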
--- python/pyspark/sql/tests.py | 5 +++++ python/pyspark/sql/types.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 59a891bd7c420..fc778631d93a3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -168,6 +168,11 @@ def test_decimal_type(self): t3 = DecimalType(8) self.assertNotEqual(t2, t3) + # regression test for SPARK-10392 + def test_datetype_equal_zero(self): + dt = DateType() + self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) + class SQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f84d08d7098ad..8bd58d69eeecd 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -168,10 +168,12 @@ def needConversion(self): return True def toInternal(self, d): - return d and d.toordinal() - self.EPOCH_ORDINAL + if d is not None: + return d.toordinal() - self.EPOCH_ORDINAL def fromInternal(self, v): - return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + if v is not None: + return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) class TimestampType(AtomicType): From c3b881a7d7e4736f7131ff002a80e25def1f63af Mon Sep 17 00:00:00 2001 From: Chuan Shao Date: Wed, 2 Sep 2015 11:02:27 -0700 Subject: [PATCH 166/802] [SPARK-7336] [HISTORYSERVER] Fix bug that applications status incorrect on JobHistory UI. Author: ArcherShao Closes #5886 from ArcherShao/SPARK-7336. --- .../deploy/history/FsHistoryProvider.scala | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index e573ff16c50a3..a5755eac36396 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} +import java.util.UUID import java.util.concurrent.{ExecutorService, Executors, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -73,7 +74,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. This is used // to ignore logs that are older during subsequent scans, to avoid processing data that // is already known. - private var lastModifiedTime = -1L + private var lastScanTime = -1L // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. 
@@ -179,15 +180,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) */ private[history] def checkForLogs(): Unit = { try { + val newLastScanTime = getNewLastScanTime() val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - var newLastModifiedTime = lastModifiedTime val logInfos: Seq[FileStatus] = statusList .filter { entry => try { getModificationTime(entry).map { time => - newLastModifiedTime = math.max(newLastModifiedTime, time) - time >= lastModifiedTime + time >= lastScanTime }.getOrElse(false) } catch { case e: AccessControlException => @@ -224,12 +224,29 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - lastModifiedTime = newLastModifiedTime + lastScanTime = newLastScanTime } catch { case e: Exception => logError("Exception in checking for event log updates", e) } } + private def getNewLastScanTime(): Long = { + val fileName = "." + UUID.randomUUID().toString + val path = new Path(logDir, fileName) + val fos = fs.create(path) + + try { + fos.close() + fs.getFileStatus(path).getModificationTime + } catch { + case e: Exception => + logError("Exception encountered when attempting to update last scan time", e) + lastScanTime + } finally { + fs.delete(path) + } + } + override def writeEventLogs( appId: String, attemptId: Option[String], From 56c4c172e99a5e14f4bc3308e7ff36d94113b63e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Sep 2015 11:13:17 -0700 Subject: [PATCH 167/802] [SPARK-10034] [SQL] add regression test for Sort on Aggregate Before #8371, there was a bug for `Sort` on `Aggregate` that we can't use aggregate expressions named `_aggOrdering` and can't use more than one ordering expressions which contains aggregate functions. The reason of this bug is that: The aggregate expression in `SortOrder` never get resolved, we alias it with `_aggOrdering` and call `toAttribute` which gives us an `UnresolvedAttribute`. So actually we are referencing aggregate expression by name, not by exprId like we thought. And if there is already an aggregate expression named `_aggOrdering` or there are more than one ordering expressions having aggregate functions, we will have conflict names and can't search by name. However, after #8371 got merged, the `SortOrder`s are guaranteed to be resolved and we are always referencing aggregate expression by exprId. The Bug doesn't exist anymore and this PR add regression tests for it. Author: Wenchen Fan Closes #8231 from cloud-fan/sort-agg. 
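For context, the Scala regression test added below has a direct PySpark counterpart, since both go through the same analyzer rule. A sketch under that assumption (an existing `sqlContext` is assumed; column names mirror the Scala test):

```python
from pyspark.sql import functions as F

df = sqlContext.createDataFrame([(1, 2)], ["i", "j"])

# Order the aggregated result by an aggregate expression while another
# aggregate is aliased "aggOrdering" -- the combination that used to clash
# with the analyzer's internal ordering alias before ordering expressions
# were referenced by exprId.
result = (df.groupBy("i")
            .agg(F.max("j").alias("aggOrdering"))
            .orderBy(F.sum("j")))

result.show()  # expected: a single row with i=1, aggOrdering=2
```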
--- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 284fff184085a..a4871e247cff7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -887,4 +887,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { .select(struct($"b")) .collect() } + + test("SPARK-10034: Sort on Aggregate with aggregation expression named 'aggOrdering'") { + val df = Seq(1 -> 2).toDF("i", "j") + val query = df.groupBy('i) + .agg(max('j).as("aggOrdering")) + .orderBy(sum('j)) + checkAnswer(query, Row(1, 2)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 9e172b2c264cb..28201073a2d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1490,6 +1490,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( + sql( + """ + |SELECT sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b), max(b) + """.stripMargin), + Row(3) :: Row(7) :: Row(11) :: Row(15) :: Nil) + checkAnswer( sql( """ From fc48307797912dc1d53893dce741ddda8630957b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Sep 2015 11:32:27 -0700 Subject: [PATCH 168/802] [SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL. Author: Wenchen Fan Closes #8548 from cloud-fan/support-order-by-non-attribute. --- .../sql/catalyst/analysis/Analyzer.scala | 72 ++++++++++--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 19 +++-- 2 files changed, 52 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1a5de15c61f86..591747b45c376 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -560,43 +560,47 @@ class Analyzer( filter } - case sort @ Sort(sortOrder, global, - aggregate @ Aggregate(grouping, originalAggExprs, child)) + case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved && !sort.resolved => // Try resolving the ordering as though it is in the aggregate clause. try { - val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) - val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) - val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] - def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions - - // Expressions that have an aggregate can be pushed down. - val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) - - // Attribute references, that are missing from the order but are present in the grouping - // expressions can also be pushed down. 
- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) - val missingAttributes = requiredAttributes -- aggregate.outputSet - val validPushdownAttributes = - missingAttributes.filter(a => grouping.exists(a.semanticEquals)) - - // If resolution was successful and we see the ordering either has an aggregate in it or - // it is missing something that is projected away by the aggregate, add the ordering - // the original aggregate operator. - if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { - val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { - case (order, evaluated) => order.copy(child = evaluated.toAttribute) - } - val aggExprsWithOrdering: Seq[NamedExpression] = - resolvedAggregateOrdering ++ originalAggExprs - - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) - } else { - sort + val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) + val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAliasedOrdering: Seq[Alias] = + resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] + + // If we pass the analysis check, then the ordering expressions should only reference to + // aggregate expressions or grouping expressions, and it's safe to push them down to + // Aggregate. + checkAnalysis(resolvedAggregate) + + val originalAggExprs = aggregate.aggregateExpressions.map( + CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) + + // If the ordering expression is same with original aggregate expression, we don't need + // to push down this ordering expression and can reference the original aggregate + // expression instead. + val needsPushDown = ArrayBuffer.empty[NamedExpression] + val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map { + case (evaluated, order) => + val index = originalAggExprs.indexWhere { + case Alias(child, _) => child semanticEquals evaluated.child + case other => other semanticEquals evaluated.child + } + + if (index == -1) { + needsPushDown += evaluated + order.copy(child = evaluated.toAttribute) + } else { + order.copy(child = originalAggExprs(index).toAttribute) + } } + + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. 
@@ -605,9 +609,7 @@ class Analyzer( } protected def containsAggregate(condition: Expression): Boolean = { - condition - .collect { case ae: AggregateExpression => ae } - .nonEmpty + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 28201073a2d7b..0ef25fe0faef0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1722,9 +1722,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10130 type coercion for IF should have children resolved first") { - val df = Seq((1, 1), (-1, 1)).toDF("key", "value") - df.registerTempTable("src") - checkAnswer( - sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer( + sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) + } + } + + test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { + withTempTable("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), + Seq(Row(1), Row(1))) + checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), + Seq(Row(1), Row(1))) + } } } From 2da3a9e98e5d129d4507b5db01bba5ee9558d28e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 2 Sep 2015 12:53:24 -0700 Subject: [PATCH 169/802] [SPARK-10004] [SHUFFLE] Perform auth checks when clients read shuffle data. To correctly isolate applications, when requests to read shuffle data arrive at the shuffle service, proper authorization checks need to be performed. This change makes sure that only the application that created the shuffle data can read from it. Such checks are only enabled when "spark.authenticate" is enabled, otherwise there's no secure way to make sure that the client is really who it says it is. Author: Marcelo Vanzin Closes #8218 from vanzin/SPARK-10004. --- .../network/netty/NettyBlockRpcServer.scala | 3 +- .../netty/NettyBlockTransferService.scala | 2 +- network/common/pom.xml | 4 + .../spark/network/client/TransportClient.java | 22 +++ .../network/sasl/SaslClientBootstrap.java | 2 + .../spark/network/sasl/SaslRpcHandler.java | 1 + .../server/OneForOneStreamManager.java | 31 +++- .../spark/network/server/StreamManager.java | 9 + .../server/TransportRequestHandler.java | 1 + .../shuffle/ExternalShuffleBlockHandler.java | 16 +- .../network/sasl/SaslIntegrationSuite.java | 163 +++++++++++++++--- .../ExternalShuffleBlockHandlerSuite.java | 2 +- project/MimaExcludes.scala | 1 + 13 files changed, 221 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7c170a742fb64..76968249fb625 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -38,6 +38,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} * is equivalent to one Spark-level shuffle block. 
*/ class NettyBlockRpcServer( + appId: String, serializer: Serializer, blockManager: BlockDataManager) extends RpcHandler with Logging { @@ -55,7 +56,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ff8aae9ebe9f0..d5ad2c9ad00e8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -49,7 +49,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager) var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { diff --git a/network/common/pom.xml b/network/common/pom.xml index 7dc3068ab8cb7..4141fcb8267a5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,6 +48,10 @@ slf4j-api provided
+ + com.google.code.findbugs + jsr305 + {% highlight python %} from pyspark.ml.regression import LinearRegression -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.util import MLUtils # Load training data -training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() +training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) From b656e6134fc5cd27e1fe6b6ab30fd7633cab0b14 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:50:35 -0700 Subject: [PATCH 251/802] [SPARK-10026] [ML] [PySpark] Implement some common Params for regression in PySpark LinearRegression and LogisticRegression lack of some Params for Python, and some Params are not shared classes which lead we need to write them for each class. These kinds of Params are list here: ```scala HasElasticNetParam HasFitIntercept HasStandardization HasThresholds ``` Here we implement them in shared params at Python side and make LinearRegression/LogisticRegression parameters peer with Scala one. Author: Yanbo Liang Closes #8508 from yanboliang/spark-10026. --- python/pyspark/ml/classification.py | 75 ++---------- .../ml/param/_shared_params_code_gen.py | 11 +- python/pyspark/ml/param/shared.py | 111 ++++++++++++++++++ python/pyspark/ml/regression.py | 42 ++----- 4 files changed, 143 insertions(+), 96 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 83f808efc3bf0..22bdd1b322aca 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -31,7 +31,8 @@ @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol): + HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, + HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): """ Logistic regression. Currently, this class only supports binary classification. @@ -65,17 +66,6 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") - thresholds = Param(Params._dummy(), "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." + - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." 
+ " If threshold and thresholds are both set, they must match.") @@ -83,40 +73,23 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) If the threshold and thresholds Params are both set, they must be equivalent. """ super(LogisticRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") #: param for threshold in binary classification, in range [0, 1]. self.threshold = Param(self, "threshold", "Threshold in binary classification prediction, in range [0, 1]." + " If threshold and thresholds are both set, they must match.") - #: param for thresholds or cutoffs in binary or multiclass classification - self.thresholds = \ - Param(self, "thresholds", - "Thresholds in multi-class classification" + - " to adjust the probability of predicting each class." + - " Array must have length equal to the number of classes, with values >= 0." 
+ - " The class with largest value p/t is predicted, where p is the original" + - " probability of that class and t is the class' threshold.") - self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, - fitIntercept=True, threshold=0.5) + self._setDefault(maxIter=100, regParam=0.1, tol=1E-6, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -124,13 +97,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, - threshold=0.5, thresholds=None, - probabilityCol="probability", rawPredictionCol="rawPrediction"): + threshold=0.5, thresholds=None, probabilityCol="probability", + rawPredictionCol="rawPrediction", standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ - threshold=0.5, thresholds=None, \ - probabilityCol="probability", rawPredictionCol="rawPrediction") + threshold=0.5, thresholds=None, probabilityCol="probability", \ + rawPredictionCol="rawPrediction", standardization=True) Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ @@ -142,32 +115,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - - def setFitIntercept(self, value): - """ - Sets the value of :py:attr:`fitIntercept`. - """ - self._paramMap[self.fitIntercept] = value - return self - - def getFitIntercept(self): - """ - Gets the value of fitIntercept or its default value. - """ - return self.getOrDefault(self.fitIntercept) - def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 926375e44871d..5b39e5dd4e25b 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -124,7 +124,16 @@ def get$Name(self): ("stepSize", "Step size to be used for each iteration of optimization.", None), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None)] + "added later.", None), + ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), + ("fitIntercept", "whether to fit an intercept term.", "True"), + ("standardization", "whether to standardize the training features before fitting the " + + "model.", "True"), + ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + + "predicting each class. Array must have length equal to the number of classes, with " + + "values >= 0. 
The class with largest value p/t is predicted, where p is the original " + + "probability of that class and t is the class' threshold.", None)] code = [] for name, doc, defaultValueStr in shared: param_code = _gen_param_header(name, doc, defaultValueStr) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 682170aee85fb..af1218128602b 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -459,6 +459,117 @@ def getHandleInvalid(self): return self.getOrDefault(self.handleInvalid) +class HasElasticNetParam(Params): + """ + Mixin for param elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + """ + + # a placeholder to make it appear in the generated doc + elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + + def __init__(self): + super(HasElasticNetParam, self).__init__() + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + self._setDefault(elasticNetParam=0.0) + + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + +class HasFitIntercept(Params): + """ + Mixin for param fitIntercept: whether to fit an intercept term.. + """ + + # a placeholder to make it appear in the generated doc + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + + def __init__(self): + super(HasFitIntercept, self).__init__() + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self._setDefault(fitIntercept=True) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + +class HasStandardization(Params): + """ + Mixin for param standardization: whether to standardize the training features before fitting the model.. + """ + + # a placeholder to make it appear in the generated doc + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + + def __init__(self): + super(HasStandardization, self).__init__() + #: param for whether to standardize the training features before fitting the model. + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self._setDefault(standardization=True) + + def setStandardization(self, value): + """ + Sets the value of :py:attr:`standardization`. + """ + self._paramMap[self.standardization] = value + return self + + def getStandardization(self): + """ + Gets the value of standardization or its default value. 
+ """ + return self.getOrDefault(self.standardization) + + +class HasThresholds(Params): + """ + Mixin for param thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.. + """ + + # a placeholder to make it appear in the generated doc + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def __init__(self): + super(HasThresholds, self).__init__() + #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + + def setThresholds(self, value): + """ + Sets the value of :py:attr:`thresholds`. + """ + self._paramMap[self.thresholds] = value + return self + + def getThresholds(self): + """ + Gets the value of thresholds or its default value. + """ + return self.getOrDefault(self.thresholds) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 44f60a769566d..a9503608b7f25 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,7 +28,8 @@ @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam, HasTol): + HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, + HasStandardization): """ Linear regression. @@ -63,38 +64,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction TypeError: Method setParams forces keyword arguments. """ - # a placeholder to make it appear in the generated doc - elasticNetParam = \ - Param(Params._dummy(), "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") - @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) """ super(LinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) - #: param for the ElasticNet mixing parameter, in range [0, 1]. 
For alpha = 0, the penalty - # is an L2 penalty. For alpha = 1, it is an L1 penalty. - self.elasticNetParam = \ - Param(self, "elasticNetParam", - "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + - "is an L2 penalty. For alpha = 1, it is an L1 penalty.") - self._setDefault(maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6): + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + standardization=True): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) + maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + standardization=True) Sets params for linear regression. """ kwargs = self.setParams._input_kwargs @@ -103,19 +96,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) - def setElasticNetParam(self, value): - """ - Sets the value of :py:attr:`elasticNetParam`. - """ - self._paramMap[self.elasticNetParam] = value - return self - - def getElasticNetParam(self): - """ - Gets the value of elasticNetParam or its default value. - """ - return self.getOrDefault(self.elasticNetParam) - class LinearRegressionModel(JavaModel): """ From b01b26260625f0ba14e5f3010207666d62d93864 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Sep 2015 08:52:28 -0700 Subject: [PATCH 252/802] [SPARK-9773] [ML] [PySpark] Add Python API for MultilayerPerceptronClassifier Add Python API for ```MultilayerPerceptronClassifier```. Author: Yanbo Liang Closes #8067 from yanboliang/SPARK-9773. --- .../MultilayerPerceptronClassifier.scala | 9 ++ python/pyspark/ml/classification.py | 132 +++++++++++++++++- 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 82fc80c58054f..5f60dea91fcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} @@ -181,6 +183,13 @@ class MultilayerPerceptronClassificationModel private[ml] ( private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + /** + * Returns layers in a Java List. + */ + private[ml] def javaLayers: java.util.List[Int] = { + layers.toList.asJava + } + /** * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 22bdd1b322aca..88815e561f572 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -26,7 +26,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', 'NaiveBayes', - 'NaiveBayesModel'] + 'NaiveBayesModel', 'MultilayerPerceptronClassifier', + 'MultilayerPerceptronClassificationModel'] @inherit_doc @@ -755,6 +756,135 @@ def theta(self): return self._call_java("theta") +@inherit_doc +class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + HasMaxIter, HasTol, HasSeed): + """ + Classifier trainer based on the Multilayer Perceptron. + Each layer has sigmoid activation function, output layer has softmax. + Number of inputs has to be equal to the size of feature vectors. + Number of outputs has to be equal to the total number of labels. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (0.0, Vectors.dense([0.0, 0.0])), + ... (1.0, Vectors.dense([0.0, 1.0])), + ... (1.0, Vectors.dense([1.0, 0.0])), + ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> model = mlp.fit(df) + >>> model.layers + [2, 5, 2] + >>> model.weights.size + 27 + >>> testDF = sqlContext.createDataFrame([ + ... (Vectors.dense([1.0, 0.0]),), + ... (Vectors.dense([0.0, 0.0]),)], ["features"]) + >>> model.transform(testDF).show() + +---------+----------+ + | features|prediction| + +---------+----------+ + |[1.0,0.0]| 1.0| + |[0.0,0.0]| 0.0| + +---------+----------+ + ... + """ + + # a placeholder to make it appear in the generated doc + layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " + + "neurons and output layer of 10 neurons, default is [1, 1].") + blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is more than " + + "remaining data in a partition then it is adjusted to the size of this " + + "data. Recommended size is between 10 and 1000, default is 128.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + """ + super(MultilayerPerceptronClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) + self.layers = Param(self, "layers", "Sizes of layers from input layer to output layer " + + "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with " + + "100 neurons and output layer of 10 neurons, default is [1, 1].") + self.blockSize = Param(self, "blockSize", "Block size for stacking input data in " + + "matrices. Data is stacked within partitions. If block size is " + + "more than remaining data in a partition then it is adjusted to " + + "the size of this data. 
Recommended size is between 10 and 1000, " + + "default is 128.") + self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128) + Sets params for MultilayerPerceptronClassifier. + """ + kwargs = self.setParams._input_kwargs + if layers is None: + return self._set(**kwargs).setLayers([1, 1]) + else: + return self._set(**kwargs) + + def _create_model(self, java_model): + return MultilayerPerceptronClassificationModel(java_model) + + def setLayers(self, value): + """ + Sets the value of :py:attr:`layers`. + """ + self._paramMap[self.layers] = value + return self + + def getLayers(self): + """ + Gets the value of layers or its default value. + """ + return self.getOrDefault(self.layers) + + def setBlockSize(self, value): + """ + Sets the value of :py:attr:`blockSize`. + """ + self._paramMap[self.blockSize] = value + return self + + def getBlockSize(self): + """ + Gets the value of blockSize or its default value. + """ + return self.getOrDefault(self.blockSize) + + +class MultilayerPerceptronClassificationModel(JavaModel): + """ + Model fitted by MultilayerPerceptronClassifier. + """ + + @property + def layers(self): + """ + array of layer sizes including input and output layers. + """ + return self._call_java("javaLayers") + + @property + def weights(self): + """ + vector of initial weights for the model that consists of the weights of layers. + """ + return self._call_java("weights") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 960d2d0ac6b5a22242a922f87f745f7d1f736181 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 11 Sep 2015 08:53:40 -0700 Subject: [PATCH 253/802] [SPARK-10537] [ML] document LIBSVM source options in public API doc and some minor improvements We should document options in public API doc. Otherwise, it is hard to find out the options without looking at the code. I tried to make `DefaultSource` private and put the documentation to package doc. However, since then there exists no public class under `source.libsvm`, the Java package doc doesn't show up in the generated html file (http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4492654). So I put the doc to `DefaultSource` instead. There are several minor updates in this PR: 1. Do `vectorType == "sparse"` only once. 2. Update `hashCode` and `equals`. 3. Remove inherited doc. 4. Delete temp dir in `afterAll`. Lewuathe Author: Xiangrui Meng Closes #8699 from mengxr/SPARK-10537. 
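Before the diff: a minimal PySpark sketch of loading data through the options this patch documents ("numFeatures" and "vectorType"). It assumes a pre-created sqlContext, as in the doctests elsewhere in this series, and the sample LIBSVM file that ships with Spark; it illustrates the documented options and is not part of the patch itself.

{{{
# Assumes sqlContext already exists (as in the PySpark doctests in this series).
df = (sqlContext.read
      .format("libsvm")
      .option("numFeatures", "780")   # skip the extra pass that infers the dimension
      .option("vectorType", "dense")  # default is "sparse"
      .load("data/mllib/sample_libsvm_data.txt"))

df.printSchema()  # two columns: label (double) and features (vector)
}}}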
--- .../ml/source/libsvm/LibSVMRelation.scala | 71 ++++++++++++------- .../{ => libsvm}/JavaLibSVMRelationSuite.java | 24 +++---- .../{ => libsvm}/LibSVMRelationSuite.scala | 14 ++-- 3 files changed, 66 insertions(+), 43 deletions(-) rename mllib/src/test/java/org/apache/spark/ml/source/{ => libsvm}/JavaLibSVMRelationSuite.java (79%) rename mllib/src/test/scala/org/apache/spark/ml/source/{ => libsvm}/LibSVMRelationSuite.scala (88%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index b12cb62a4ef15..1f627777fc68d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -21,12 +21,12 @@ import com.google.common.base.Objects import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{StructType, StructField, DoubleType} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** * LibSVMRelation provides the DataFrame constructed from LibSVM format data. @@ -35,7 +35,7 @@ import org.apache.spark.sql.sources._ * @param vectorType The type of vector. It can be 'sparse' or 'dense' * @param sqlContext The Spark SQLContext */ -private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) +private[libsvm] class LibSVMRelation(val path: String, val numFeatures: Int, val vectorType: String) (@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with Logging with Serializable { @@ -47,27 +47,56 @@ private[ml] class LibSVMRelation(val path: String, val numFeatures: Int, val vec override def buildScan(): RDD[Row] = { val sc = sqlContext.sparkContext val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - + val sparse = vectorType == "sparse" baseRdd.map { pt => - val features = if (vectorType == "dense") pt.features.toDense else pt.features.toSparse + val features = if (sparse) pt.features.toSparse else pt.features.toDense Row(pt.label, features) } } override def hashCode(): Int = { - Objects.hashCode(path, schema) + Objects.hashCode(path, Double.box(numFeatures), vectorType) } override def equals(other: Any): Boolean = other match { - case that: LibSVMRelation => (this.path == that.path) && this.schema.equals(that.schema) - case _ => false + case that: LibSVMRelation => + path == that.path && + numFeatures == that.numFeatures && + vectorType == that.vectorType + case _ => + false } - } /** - * This is used for creating DataFrame from LibSVM format file. - * The LibSVM file path must be specified to DefaultSource. + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. 
+ * + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = sqlContext.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * DataFrame df = sqlContext.read.format("libsvm") + * .option("numFeatures, "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass. + * This is also useful when the dataset is already split into multiple files and you want to load + * them separately, because some features may not present in certain files, which leads to + * inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] */ @Since("1.6.0") class DefaultSource extends RelationProvider with DataSourceRegister { @@ -75,24 +104,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" - private def checkPath(parameters: Map[String, String]): String = { - require(parameters.contains("path"), "'path' must be specified") - parameters.get("path").get - } - - /** - * Returns a new base relation with the given parameters. - * Note: the parameters' keywords are case insensitive and this insensitivity is enforced - * by the Map that is passed to the function. - */ + @Since("1.6.0") override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) : BaseRelation = { - val path = checkPath(parameters) + val path = parameters.getOrElse("path", + throw new IllegalArgumentException("'path' must be specified")) val numFeatures = parameters.getOrElse("numFeatures", "-1").toInt - /** - * featuresType can be selected "dense" or "sparse". - * This parameter decides the type of returned feature vector. - */ val vectorType = parameters.getOrElse("vectorType", "sparse") new LibSVMRelation(path, numFeatures, vectorType)(sqlContext) } diff --git a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java similarity index 79% rename from mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java rename to mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 11fa4eec0ccf0..2976b38e45031 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.ml.source; +package org.apache.spark.ml.source.libsvm; import java.io.File; import java.io.IOException; @@ -42,34 +42,34 @@ */ public class JavaLibSVMRelationSuite { private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient DataFrame dataset; + private transient SQLContext sqlContext; - private File tmpDir; - private File path; + private File tempDir; + private String path; @Before public void setUp() throws IOException { jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - jsql = new SQLContext(jsc); - - tmpDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); - path = new File(tmpDir.getPath(), "part-00000"); + sqlContext = new SQLContext(jsc); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); + File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, path, Charsets.US_ASCII); + Files.write(s, file, Charsets.US_ASCII); + path = tempDir.toURI().toString(); } @After public void tearDown() { jsc.stop(); jsc = null; - Utils.deleteRecursively(tmpDir); + Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - dataset = jsql.read().format("libsvm").option("vectorType", "dense").load(path.getPath()); + DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); Row r = dataset.first(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 8ed134128c8d2..997f574e51f6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml.source +package org.apache.spark.ml.source.libsvm import java.io.File @@ -23,11 +23,12 @@ import com.google.common.base.Charsets import com.google.common.io.Files import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{SparseVector, Vectors, DenseVector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { + var tempDir: File = _ var path: String = _ override def beforeAll(): Unit = { @@ -38,12 +39,17 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Utils.createTempDir() - val file = new File(tempDir.getPath, "part-00000") + tempDir = Utils.createTempDir() + val file = new File(tempDir, "part-00000") Files.write(lines, file, Charsets.US_ASCII) path = tempDir.toURI.toString } + override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + super.afterAll() + } + test("select as sparse vector") { val df = sqlContext.read.format("libsvm").load(path) assert(df.columns(0) == "label") From 2e3a280754a28dc36a71b9ff988e34cbf457f6c3 Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Fri, 11 Sep 2015 08:55:35 -0700 Subject: [PATCH 254/802] [MINOR] [MLLIB] [ML] [DOC] Minor doc fixes for StringIndexer and MetadataUtils MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: * Make Scala doc for StringIndexerInverse clearer. Also remove Scala doc from transformSchema, so that the doc is inherited. * MetadataUtils.scala: “ Helper utilities for tree-based algorithms” —> not just trees anymore CC: holdenk mengxr Author: Joseph K. Bradley Closes #8679 from jkbradley/doc-fixes-1.5. --- .../spark/ml/feature/StringIndexer.scala | 31 +++++++------------ .../apache/spark/ml/util/MetadataUtils.scala | 2 +- python/pyspark/ml/feature.py | 16 +++++----- 3 files changed, 20 insertions(+), 29 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index b6482ffe0b2ee..3a4ab9a857648 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -181,10 +181,10 @@ class StringIndexerModel ( /** * :: Experimental :: - * A [[Transformer]] that maps a column of string indices back to a new column of corresponding - * string values using either the ML attributes of the input column, or if provided using the labels - * supplied by the user. - * All original columns are kept during transformation. + * A [[Transformer]] that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping is either from the ML attributes of the input column, + * or from user-supplied labels (which take precedence over ML attributes). * * @see [[StringIndexer]] for converting strings into indices */ @@ -202,32 +202,23 @@ class IndexToString private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. The default value is an empty array, - * but the empty array is ignored and column metadata used instead. - * @group setParam - */ + /** @group setParam */ def setLabels(value: Array[String]): this.type = set(labels, value) /** - * Param for array of labels. - * Optional labels to be provided by the user. - * Default: Empty array, in which case column metadata is used for labels. + * Optional param for array of labels specifying index-string mapping. + * + * Default: Empty array, in which case [[inputCol]] metadata is used for labels. * @group param */ final val labels: StringArrayParam = new StringArrayParam(this, "labels", - "array of labels, if not provided metadata from inputCol is used instead.") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") setDefault(labels, Array.empty[String]) - /** - * Optional labels to be provided by the user, if not supplied column - * metadata is read for labels. 
- * @group getParam - */ + /** @group getParam */ final def getLabels: Array[String] = $(labels) - /** Transform the schema for the inverse transformation */ override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index fcb517b5f735e..96a38a3bde960 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types.StructField /** - * Helper utilities for tree-based algorithms + * Helper utilities for algorithms using ML metadata */ private[spark] object MetadataUtils { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 71dc636b83eac..97cbee73a05ed 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -985,17 +985,17 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): """ .. note:: Experimental - A :py:class:`Transformer` that maps a column of string indices back to a new column of - corresponding string values using either the ML attributes of the input column, or if - provided using the labels supplied by the user. - All original columns are kept during transformation. + A :py:class:`Transformer` that maps a column of indices back to a new column of + corresponding string values. + The index-string mapping is either from the ML attributes of the input column, + or from user-supplied labels (which take precedence over ML attributes). See L{StringIndexer} for converting strings into indices. """ # a placeholder to make the labels show up in generated doc labels = Param(Params._dummy(), "labels", - "Optional array of labels to be provided by the user, if not supplied or " + - "empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping." + + " If not provided or if empty, then metadata from inputCol is used instead.") @keyword_only def __init__(self, inputCol=None, outputCol=None, labels=None): @@ -1006,8 +1006,8 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) self.labels = Param(self, "labels", - "Optional array of labels to be provided by the user, if not " + - "supplied or empty, column metadata is read for labels") + "Optional array of labels specifying index-string mapping. If not" + + " provided or if empty, then metadata from inputCol is used instead.") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) From 6ce0886eb0916a985db142c0b6d2c2b14db5063d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Sep 2015 09:42:53 -0700 Subject: [PATCH 255/802] [SPARK-10540] [SQL] Ignore HadoopFsRelationTest's "test all data types" if it is too flaky If hadoopFsRelationSuites's "test all data types" is too flaky we can disable it for now. https://issues.apache.org/jira/browse/SPARK-10540 Author: Yin Huai Closes #8705 from yhuai/SPARK-10540-ignore. 
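Relating back to the IndexToString change above, here is a minimal PySpark sketch of how the labels param interacts with column metadata. The toy data and label strings are made up, and sqlContext is assumed to exist; when labels is supplied, index i maps to labels[i], otherwise the mapping comes from the ML attributes written by StringIndexer.

{{{
from pyspark.ml.feature import StringIndexer, IndexToString

# Hypothetical toy data.
df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a")], ["id", "category"])

indexed = StringIndexer(inputCol="category", outputCol="categoryIndex") \
    .fit(df).transform(df)

# No labels supplied: the index-string mapping is read from the ML attributes
# (column metadata) that StringIndexer attached to "categoryIndex".
fromMetadata = IndexToString(inputCol="categoryIndex", outputCol="original")

# Explicit labels: these take precedence over the metadata; labels[i] is used
# for index i, so the list must cover every index present in the column.
fromLabels = IndexToString(inputCol="categoryIndex", outputCol="original",
                           labels=["first", "second", "third"])

fromMetadata.transform(indexed).show()
fromLabels.transform(indexed).show()
}}}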
--- .../org/apache/spark/sql/sources/hadoopFsRelationSuites.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 24f43cf7c15ca..13223c61584b2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,7 +100,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - test("test all data types") { + ignore("test all data types") { withTempPath { file => // Create the schema. val struct = From 5f46444765a377696af76af6e2c77ab14bfdab8e Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 11 Sep 2015 10:32:35 -0700 Subject: [PATCH 256/802] [SPARK-8530] [ML] add python API for MinMaxScaler jira: https://issues.apache.org/jira/browse/SPARK-8530 add python API for MinMaxScaler jira for MinMaxScaler: https://issues.apache.org/jira/browse/SPARK-7514 Author: Yuhao Yang Closes #7150 from hhbyyh/pythonMinMax. --- python/pyspark/ml/feature.py | 104 +++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 97cbee73a05ed..92db8df80280b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,11 +27,11 @@ from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', - 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', - 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', - 'Word2Vec', 'Word2VecModel'] + 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', + 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', + 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', + 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -406,6 +406,100 @@ class IDFModel(JavaModel): """ +@inherit_doc +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Rescale each feature individually to a common range [min, max] linearly using column summary + statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + feature E is calculated as, + + Rescaled(e_i) = (e_i - E_min) / (E_max - E_min) * (max - min) + min + + For the case E_max == E_min, Rescaled(e_i) = 0.5 * (max + min) + + Note that since zero values will probably be transformed to non-zero values, output of the + transformer will be DenseVector even for sparse input. + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") + >>> model = mmScaler.fit(df) + >>> model.transform(df).show() + +-----+------+ + | a|scaled| + +-----+------+ + |[0.0]| [0.0]| + |[2.0]| [1.0]| + +-----+------+ + ... 
+ """ + + # a placeholder to make it appear in the generated doc + min = Param(Params._dummy(), "min", "Lower bound of the output feature range") + max = Param(Params._dummy(), "max", "Upper bound of the output feature range") + + @keyword_only + def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + """ + super(MinMaxScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) + self.min = Param(self, "min", "Lower bound of the output feature range") + self.max = Param(self, "max", "Upper bound of the output feature range") + self._setDefault(min=0.0, max=1.0) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): + """ + setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) + Sets params for this MinMaxScaler. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMin(self, value): + """ + Sets the value of :py:attr:`min`. + """ + self._paramMap[self.min] = value + return self + + def getMin(self): + """ + Gets the value of min or its default value. + """ + return self.getOrDefault(self.min) + + def setMax(self, value): + """ + Sets the value of :py:attr:`max`. + """ + self._paramMap[self.max] = value + return self + + def getMax(self): + """ + Gets the value of max or its default value. + """ + return self.getOrDefault(self.max) + + def _create_model(self, java_model): + return MinMaxScalerModel(java_model) + + +class MinMaxScalerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by :py:class:`MinMaxScaler`. + """ + + @inherit_doc @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol): From b231ab8938ae3c4fc2089cfc69c0d8164807d533 Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 11 Sep 2015 21:45:45 +0100 Subject: [PATCH 257/802] [SPARK-10546] Check partitionId's range in ExternalSorter#spill() See this thread for background: http://search-hadoop.com/m/q3RTt0rWvIkHAE81 We should check the range of partition Id and provide meaningful message through exception. Alternatively, we can use abs() and modulo to force the partition Id into legitimate range. However, expectation is that user should correct the logic error in his / her code. Author: tedyu Closes #8703 from tedyu/master. 
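A plain-Python sketch of the fail-fast pattern this patch adds (it is not Spark code; the message text simply mirrors the new require()):

{{{
def check_partition_id(partition_id, num_partitions):
    # Fail fast with the offending value instead of letting a bad id from a
    # user-supplied partitioner corrupt the spill files further down.
    if not (0 <= partition_id < num_partitions):
        raise ValueError("partition Id: %d should be in the range [0, %d)"
                         % (partition_id, num_partitions))
    return partition_id

check_partition_id(3, 4)    # ok
check_partition_id(-1, 4)   # raises ValueError with a meaningful message
}}}

As the commit message notes, forcing the id into range with abs() and modulo would hide the logic error, so failing fast with a clear message was preferred.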
--- .../scala/org/apache/spark/util/collection/ExternalSorter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 138c05dff19e4..31230d5978b2a 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -297,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C]( val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { val partitionId = it.nextPartition() + require(partitionId >= 0 && partitionId < numPartitions, + s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})") it.writeNext(writer) elementsPerPartition(partitionId) += 1 objectsWritten += 1 From c373866774c082885a50daaf7c83f3a14b0cd714 Mon Sep 17 00:00:00 2001 From: Icaro Medeiros Date: Fri, 11 Sep 2015 21:46:52 +0100 Subject: [PATCH 258/802] [PYTHON] Fixed typo in exception message Just fixing a typo in exception message, raised when attempting to pickle SparkContext. Author: Icaro Medeiros Closes #8724 from icaromedeiros/master. --- python/pyspark/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1b2a52ad64114..a0a1ccbeefb09 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -255,7 +255,7 @@ def __getnewargs__(self): # This method is called when attempting to pickle SparkContext, which is always an error: raise Exception( "It appears that you are attempting to reference SparkContext from a broadcast " - "variable, action, or transforamtion. SparkContext can only be used on the driver, " + "variable, action, or transformation. SparkContext can only be used on the driver, " "not in code that it run on workers. For more information, see SPARK-5063." ) From d5d647380f93f4773f9cb85ea6544892d409b5a1 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 Sep 2015 14:15:16 -0700 Subject: [PATCH 259/802] [SPARK-10442] [SQL] fix string to boolean cast When we cast string to boolean in hive, it returns `true` if the length of string is > 0, and spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throw exception when try to cast string to boolean. 6. 
mysql, oracle, sqlserver don't have boolean type Whether we should change the cast behavior according to other SQL system or not is not decided yet, this PR is a test to see if we changed, how many compatibility tests will fail. Author: Wenchen Fan Closes #8698 from cloud-fan/string2boolean. --- .../spark/sql/catalyst/expressions/Cast.scala | 24 +++++++- .../spark/sql/catalyst/util/StringUtils.scala | 8 +++ .../sql/catalyst/expressions/CastSuite.scala | 61 ++++++++++++------- .../sql/sources/hadoopFsRelationSuites.scala | 13 ++++ 4 files changed, 82 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2db954257be35..f0bce388d959a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType) // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, _.numBytes() != 0) + buildCast[UTF8String](_, s => { + if (StringUtils.isTrueString(s)) { + true + } else if (StringUtils.isFalseString(s)) { + false + } else { + null + } + }) case TimestampType => buildCast[Long](_, t => t != 0) case DateType => @@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToBooleanCode(from: DataType): CastFunction = from match { case StringType => - (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + val stringUtils = StringUtils.getClass.getName.stripSuffix("$") + (c, evPrim, evNull) => + s""" + if ($stringUtils.isTrueString($c)) { + $evPrim = true; + } else if ($stringUtils.isFalseString($c)) { + $evPrim = false; + } else { + $evNull = true; + } + """ case TimestampType => (c, evPrim, evNull) => s"$evPrim = $c != 0;" case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 9ddfb3a0d3759..c2eeb3c5650ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util import java.util.regex.Pattern +import org.apache.spark.unsafe.types.UTF8String + object StringUtils { // replace the _ with .{1} exactly match 1 time of any character @@ -44,4 +46,10 @@ object StringUtils { v } } + + private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString) + private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString) + + def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) + def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1ad70733eae03..f4db4da7646f8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from array") { - val array = Literal.create(Seq("123", "abc", "", null), + val array = Literal.create(Seq("123", "true", "f", null), ArrayType(StringType, containsNull = true)) - val array_notNull = Literal.create(Seq("123", "abc", ""), + val array_notNull = Literal.create(Seq("123", "true", "f"), ArrayType(StringType, containsNull = false)) checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) @@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false, null)) + checkEvaluation(ret, Seq(null, true, false, null)) } { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) @@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Seq(true, true, false)) + checkEvaluation(ret, Seq(null, true, false)) } { @@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from map") { val map = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null), + Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null), MapType(StringType, StringType, valueContainsNull = true)) val map_notNull = Literal.create( - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Map("a" -> "123", "b" -> "true", "c" -> "f"), MapType(StringType, StringType, valueContainsNull = false)) checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) @@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null)) } { val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false)) @@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === true) - checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false)) + checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with 
ExpressionEvalHelper { val struct = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString(""), + UTF8String.fromString("true"), + UTF8String.fromString("f"), null), StructType(Seq( StructField("a", StringType, nullable = true), @@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val struct_notNull = Literal.create( InternalRow( UTF8String.fromString("123"), - UTF8String.fromString("abc"), - UTF8String.fromString("")), + UTF8String.fromString("true"), + UTF8String.fromString("f")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("c", BooleanType, nullable = true), StructField("d", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false, null)) + checkEvaluation(ret, InternalRow(null, true, false, null)) } { val ret = cast(struct, StructType(Seq( @@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = true)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { val ret = cast(struct_notNull, StructType(Seq( @@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === true) - checkEvaluation(ret, InternalRow(true, true, false)) + checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( Row( - Seq("123", "abc", ""), - Map("a" ->"123", "b" -> "abc", "c" -> ""), + Seq("123", "true", "f"), + Map("a" ->"123", "b" -> "true", "c" -> "f"), Row(0)), StructType(Seq( StructField("a", @@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, Row( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map("a" -> null, "b" -> true, "c" -> false), Row(0L))) } - test("case between string and interval") { + test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), @@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } + + test("cast string to boolean") { + checkCast("t", true) + checkCast("true", true) + checkCast("tRUe", true) + checkCast("y", true) + checkCast("yes", true) + checkCast("1", true) + + checkCast("f", false) + checkCast("false", false) + checkCast("FAlsE", false) + checkCast("n", false) + checkCast("no", false) + checkCast("0", false) + + checkEvaluation(cast("abc", BooleanType), null) + checkEvaluation(cast("", BooleanType), null) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 13223c61584b2..8ffcef85668d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } + test("saveAsTable()/load() - partitioned table - boolean type") { + sqlContext.range(2) + .select('id, ('id % 2 === 0).as("b")) + .write.partitionBy("b").saveAsTable("t") + + withTable("t") { + checkAnswer( + sqlContext.table("t").sort('id), + Row(0, true) :: Row(1, false) :: Nil + ) + } + } + test("saveAsTable()/load() - partitioned table - Overwrite") { partitionedTestDF.write .format(dataSourceName) From 1eede3b254ee3793841c92971707094ac8afee35 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 11 Sep 2015 14:55:15 -0700 Subject: [PATCH 260/802] [SPARK-7142] [SQL] Minor enhancement to BooleanSimplification Optimizer rule. Incorporate review comments Adding changes suggested by cloud-fan in #5700 cc marmbrus Author: Yash Datta Closes #8716 from saucam/bool_simp. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d9b50f3c97da0..0f4caec7451a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -435,10 +435,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { // a && a => a case (l, r) if l fastEquals r => l // a && (not(a) || b) => a && b - case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r) - case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r) - case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r) - case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r) + case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r) + case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r) + case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r) + case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r) // (a || b) && (a || c) => a || (b && c) case _ => // 1. Split left and right to get the disjunctive predicates, From e626ac5f5c27dcc74113070f2fec03682bcd12bd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 11 Sep 2015 15:00:13 -0700 Subject: [PATCH 261/802] [SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators This PR is in conflict with #8535. I will update this one when #8535 gets merged. Author: zsxwing Closes #8573 from zsxwing/more-local-operators. 
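All three new operators follow the same pull-based open()/next()/fetch()/close() protocol. Below is a minimal Python sketch of that protocol for the intersect case, mirroring the Scala IntersectNode in the diff that follows; the children are assumed to expose the same four methods and to produce hashable rows, and this is an illustration rather than Spark's API.

{{{
class IntersectNode(object):
    def __init__(self, left, right):
        self.left, self.right = left, right
        self.left_rows = None
        self.current = None

    def open(self):
        # Drain the left child eagerly into a set, then open the right child.
        # (The Scala version copies each row because InternalRows are reused.)
        self.left.open()
        self.left_rows = set()
        while self.left.next():
            self.left_rows.add(self.left.fetch())
        self.left.close()
        self.right.open()

    def next(self):
        # Advance the right child until a row that also appears on the left is found.
        self.current = None
        while self.current is None and self.right.next():
            row = self.right.fetch()
            if row in self.left_rows:
                self.current = row
        return self.current is not None

    def fetch(self):
        return self.current

    def close(self):
        self.left.close()
        self.right.close()
}}}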
--- .../spark/sql/execution/basicOperators.scala | 2 +- .../sql/execution/local/IntersectNode.scala | 63 ++++++++++++++ .../spark/sql/execution/local/LocalNode.scala | 5 ++ .../sql/execution/local/SampleNode.scala | 82 +++++++++++++++++++ .../local/TakeOrderedAndProjectNode.scala | 73 +++++++++++++++++ .../execution/local/IntersectNodeSuite.scala | 35 ++++++++ .../sql/execution/local/SampleNodeSuite.scala | 40 +++++++++ .../TakeOrderedAndProjectNodeSuite.scala | 54 ++++++++++++ 8 files changed, 353 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 3f68b05a24f44..bf6d44c098ee3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed - * @param child the QueryPlan + * @param child the SparkPlan */ @DeveloperApi case class Sample( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala new file mode 100644 index 0000000000000..740d485f8d9e6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala @@ -0,0 +1,63 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import scala.collection.mutable + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode) + extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = left.output + + private[this] var leftRows: mutable.HashSet[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + left.open() + leftRows = mutable.HashSet[InternalRow]() + while (left.next()) { + leftRows += left.fetch().copy() + } + left.close() + right.open() + } + + override def next(): Boolean = { + currentRow = null + while (currentRow == null && right.next()) { + currentRow = right.fetch() + if (!leftRows.contains(currentRow)) { + currentRow = null + } + } + currentRow != null + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index c4f8ae304db39..a2c275db9b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** + * Returns the content through the [[Iterator]] interface. + */ + final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this) + /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala new file mode 100644 index 0000000000000..abf3df1c0c2af --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +import java.util.Random + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + +/** + * Sample the dataset. + * + * @param conf the SQLConf + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. 
+ * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LocalNode + */ +case class SampleNode( + conf: SQLConf, + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LocalNode) extends UnaryLocalNode(conf) { + + override def output: Seq[Attribute] = child.output + + private[this] var iterator: Iterator[InternalRow] = _ + + private[this] var currentRow: InternalRow = _ + + override def open(): Unit = { + child.open() + val (sampler, _seed) = if (withReplacement) { + val random = new Random(seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result + // of DataFrame + random.nextLong()) + } else { + (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + } + sampler.setSeed(_seed) + iterator = sampler.sample(child.asIterator) + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala new file mode 100644 index 0000000000000..53f1dcc65d8cf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.util.BoundedPriorityQueue + +case class TakeOrderedAndProjectNode( + conf: SQLConf, + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: LocalNode) extends UnaryLocalNode(conf) { + + private[this] var projection: Option[Projection] = _ + private[this] var ord: InterpretedOrdering = _ + private[this] var iterator: Iterator[InternalRow] = _ + private[this] var currentRow: InternalRow = _ + + override def output: Seq[Attribute] = { + val projectOutput = projectList.map(_.map(_.toAttribute)) + projectOutput.getOrElse(child.output) + } + + override def open(): Unit = { + child.open() + projection = projectList.map(new InterpretedProjection(_, child.output)) + ord = new InterpretedOrdering(sortOrder, child.output) + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse) + while (child.next()) { + queue += child.fetch() + } + // Close it eagerly since we don't need it. + child.close() + iterator = queue.iterator + } + + override def next(): Boolean = { + if (iterator.hasNext) { + val _currentRow = iterator.next() + currentRow = projection match { + case Some(p) => p(_currentRow) + case None => _currentRow + } + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = child.close() + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala new file mode 100644 index 0000000000000..7deaa375fcfc2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +class IntersectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("basic") { + val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer2( + input1, + input2, + (node1, node2) => IntersectNode(conf, node1, node2), + input1.intersect(input2).collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala new file mode 100644 index 0000000000000..87a7da453999c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.local + +class SampleNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def testSample(withReplacement: Boolean): Unit = { + test(s"withReplacement: $withReplacement") { + val seed = 0L + val input = sqlContext.sparkContext. + parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition + toDF("key", "value") + checkAnswer( + input, + node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), + input.sample(withReplacement, 0.3, seed).collect() + ) + } + } + + testSample(withReplacement = true) + testSample(withReplacement = false) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala new file mode 100644 index 0000000000000..ff28b24eeff14 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} + +class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { + val sortOrder: Seq[SortOrder] = sortExprs.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + sortOrder + } + + private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { + val testCaseName = if (desc) "desc" else "asc" + test(testCaseName) { + val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") + val sortColumn = if (desc) input.col("key").desc else input.col("key") + checkAnswer( + input, + node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), + input.sort(sortColumn).limit(5).collect() + ) + } + } + + testTakeOrderedAndProjectNode(desc = false) + testTakeOrderedAndProjectNode(desc = true) +} From c2af42b5f32287ff595ad027a8191d4b75702d8d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:01:37 -0700 Subject: [PATCH 262/802] [SPARK-9990] [SQL] Local hash join follow-ups 1. Hide `LocalNodeIterator` behind the `LocalNode#asIterator` method 2. Add tests for this Author: Andrew Or Closes #8708 from andrewor14/local-hash-join-follow-up. --- .../sql/execution/joins/HashedRelation.scala | 7 +- .../sql/execution/local/HashJoinNode.scala | 3 +- .../spark/sql/execution/local/LocalNode.scala | 4 +- .../sql/execution/local/LocalNodeSuite.scala | 116 ++++++++++++++++++ 4 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 0cff21ca618b4..bc255b27502b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -25,7 +25,8 @@ import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.execution.metric.LongSQLMetric +import org.apache.spark.sql.execution.local.LocalNode +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -113,6 +114,10 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR private[execution] object HashedRelation { + def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = { + apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator) + } + def apply( input: Iterator[InternalRow], numInputRows: LongSQLMetric, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala index a3e68d6a7c341..e7b24e3fca2b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala @@ -75,8 +75,7 @@ case class HashJoinNode( override def open(): Unit = { buildNode.open() - hashed = HashedRelation.apply( - new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator) + hashed = HashedRelation(buildNode, buildSideKeyGenerator) streamedNode.open() joinRow = new JoinedRow resultProjection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index a2c275db9b35d..e540ef8555eb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -77,7 +77,7 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging /** * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq. */ - def collect(): Seq[Row] = { + final def collect(): Seq[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output)) val result = new scala.collection.mutable.ArrayBuffer[Row] open() @@ -140,7 +140,7 @@ abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) { /** * An thin wrapper around a [[LocalNode]] that provides an `Iterator` interface. */ -private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { +private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] { private var nextRow: InternalRow = _ override def hasNext: Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala new file mode 100644 index 0000000000000..b89fa46f8b3b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -0,0 +1,116 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.IntegerType + +class LocalNodeSuite extends SparkFunSuite { + private val data = (1 to 100).toArray + + test("basic open, next, fetch, close") { + val node = new DummyLocalNode(data) + assert(!node.isOpen) + node.open() + assert(node.isOpen) + data.foreach { i => + assert(node.next()) + // fetch should be idempotent + val fetched = node.fetch() + assert(node.fetch() === fetched) + assert(node.fetch() === fetched) + assert(node.fetch().numFields === 1) + assert(node.fetch().getInt(0) === i) + } + assert(!node.next()) + node.close() + assert(!node.isOpen) + } + + test("asIterator") { + val node = new DummyLocalNode(data) + val iter = node.asIterator + node.open() + data.foreach { i => + // hasNext should be idempotent + assert(iter.hasNext) + assert(iter.hasNext) + val item = iter.next() + assert(item.numFields === 1) + assert(item.getInt(0) === i) + } + intercept[NoSuchElementException] { + iter.next() + } + node.close() + } + + test("collect") { + val node = new DummyLocalNode(data) + node.open() + val collected = node.collect() + assert(collected.size === data.size) + assert(collected.forall(_.size === 1)) + assert(collected.map(_.getInt(0)) === data) + node.close() + } + +} + +/** + * A dummy [[LocalNode]] that just returns one row per integer in the input. + */ +private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { + private var index = Int.MinValue + + def this(input: Array[Int]) { + this(new SQLConf, input) + } + + def isOpen: Boolean = { + index != Int.MinValue + } + + override def output: Seq[Attribute] = { + Seq(AttributeReference("something", IntegerType)()) + } + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + val values = Array(input(index).asInstanceOf[Any]) + new GenericInternalRow(values) + } + + override def close(): Unit = { + index = Int.MinValue + } +} From d74c6a143cbd060c25bf14a8d306841b3ec55d03 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 11 Sep 2015 15:02:59 -0700 Subject: [PATCH 263/802] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test This commit ensures if an assertion fails within a thread, it will ultimately fail the test. Otherwise we end up potentially masking real bugs by not propagating assertion failures properly. Author: Andrew Or Closes #8723 from andrewor14/fix-threading-suite. 
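To make the fix concrete before the diff, here is a minimal, self-contained sketch of the pattern the patch applies (written in Java with illustrative names; the actual patch below changes Scala test code): a failure raised on a worker thread is recorded and rethrown on the main thread, so it cannot be silently swallowed.

```
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;

public final class PropagateWorkerFailure {
  public static void main(String[] args) throws Throwable {
    final AtomicReference<Throwable> failure = new AtomicReference<>();
    final Semaphore sem = new Semaphore(0);
    Thread worker = new Thread() {
      @Override
      public void run() {
        try {
          // Stand-in for a test assertion that may fail on this thread.
          if (1 + 1 != 2) {
            throw new AssertionError("unexpected result");
          }
        } catch (Throwable t) {
          failure.set(t);   // remember the failure instead of losing it
        } finally {
          sem.release();    // always unblock the waiting thread
        }
      }
    };
    worker.start();
    sem.acquire();
    Throwable t = failure.get();
    if (t != null) {
      throw t;              // rethrow on the main thread so the test fails
    }
  }
}
```

Without the try/catch/finally, an `AssertionError` thrown inside `run()` only terminates the worker thread; the enclosing test keeps running and passes, which is exactly the masking problem described above.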
--- .../org/apache/spark/ThreadingSuite.scala | 68 ++++++++++++------- 1 file changed, 45 insertions(+), 23 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 48509f0759a3b..cda2b245526f7 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { val nums = sc.parallelize(1 to 2, 2) val sem = new Semaphore(0) ThreadingSuiteState.clear() + var throwable: Option[Throwable] = None for (i <- 0 until 2) { new Thread { override def run() { - val ans = nums.map(number => { - val running = ThreadingSuiteState.runningThreads - running.getAndIncrement() - val time = System.currentTimeMillis() - while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { - Thread.sleep(100) - } - if (running.get() != 4) { - ThreadingSuiteState.failed.set(true) - } - number - }).collect() - assert(ans.toList === List(1, 2)) - sem.release() + try { + val ans = nums.map(number => { + val running = ThreadingSuiteState.runningThreads + running.getAndIncrement() + val time = System.currentTimeMillis() + while (running.get() != 4 && System.currentTimeMillis() < time + 1000) { + Thread.sleep(100) + } + if (running.get() != 4) { + ThreadingSuiteState.failed.set(true) + } + number + }).collect() + assert(ans.toList === List(1, 2)) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } }.start() } @@ -145,18 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } + throwable.foreach { t => throw t } } test("set local properties in different thread") { sc = new SparkContext("local", "test") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -165,20 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === null) + throwable.foreach { t => throw t } } test("set and get local properties in parent-children thread") { sc = new SparkContext("local", "test") sc.setLocalProperty("test", "parent") val sem = new Semaphore(0) - + var throwable: Option[Throwable] = None val threads = (1 to 5).map { i => new Thread() { override def run() { - assert(sc.getLocalProperty("test") === "parent") - sc.setLocalProperty("test", i.toString) - assert(sc.getLocalProperty("test") === i.toString) - sem.release() + try { + assert(sc.getLocalProperty("test") === "parent") + sc.setLocalProperty("test", i.toString) + assert(sc.getLocalProperty("test") === i.toString) + } catch { + case t: Throwable => + throwable = Some(t) + } finally { + sem.release() + } } } } @@ -188,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { sem.acquire(5) assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) + throwable.foreach { t => throw 
t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { From c34fc19765bdf55365cdce78d9ba11b220b73bb6 Mon Sep 17 00:00:00 2001 From: 0x0FFF Date: Fri, 11 Sep 2015 15:19:04 -0700 Subject: [PATCH 264/802] [SPARK-9014] [SQL] Allow Python spark API to use built-in exponential operator This PR addresses (SPARK-9014)[https://issues.apache.org/jira/browse/SPARK-9014] Added functionality: `Column` object in Python now supports exponential operator `**` Example: ``` from pyspark.sql import * df = sqlContext.createDataFrame([Row(a=2)]) df.select(3**df.a,df.a**3,df.a**df.a).collect() ``` Outputs: ``` [Row(POWER(3.0, a)=9.0, POWER(a, 3.0)=8.0, POWER(a, a)=4.0)] ``` Author: 0x0FFF Closes #8658 from 0x0FFF/SPARK-9014. --- python/pyspark/sql/column.py | 13 +++++++++++++ python/pyspark/sql/tests.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 573f65f5bf096..9ca8e1f264cfa 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -91,6 +91,17 @@ def _(self): return _ +def _bin_func_op(name, reverse=False, doc="binary function"): + def _(self, other): + sc = SparkContext._active_spark_context + fn = getattr(sc._jvm.functions, name) + jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) + njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc) + return Column(njc) + _.__doc__ = doc + return _ + + def _bin_op(name, doc="binary operator"): """ Create a method for given binary operator """ @@ -151,6 +162,8 @@ def __init__(self, jc): __rdiv__ = _reverse_op("divide") __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") + __pow__ = _bin_func_op("pow") + __rpow__ = _bin_func_op("pow", reverse=True) # logistic operators __eq__ = _bin_op("equalTo") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index eb449e8679fa0..f2172b7a27d88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -568,7 +568,7 @@ def test_column_operators(self): cs = self.df.value c = ci == cs self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) - rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1) self.assertTrue(all(isinstance(c, Column) for c in rcc)) cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) From 6d8367807cb62c2cb139cee1d039dc8b12c63385 Mon Sep 17 00:00:00 2001 From: Daniel Imfeld Date: Sat, 12 Sep 2015 09:19:59 +0100 Subject: [PATCH 265/802] [SPARK-10566] [CORE] SnappyCompressionCodec init exception handling masks important error information When throwing an IllegalArgumentException in SnappyCompressionCodec.init, chain the existing exception. This allows potentially important debugging info to be passed to the user. Manual testing shows the exception chained properly, and the test suite still looks fine as well. This contribution is my original work and I license the work to the project under the project's open source license. Author: Daniel Imfeld Closes #8725 from dimfeld/dimfeld-patch-1. 
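As a quick illustration of why chaining matters (a hedged sketch with a made-up cause, not code from the patch): constructing the wrapper exception with the original error as its cause preserves the root message and stack trace, which a bare `new IllegalArgumentException()` discards.

```
public final class ChainingSketch {
  public static void main(String[] args) {
    // Hypothetical root cause, e.g. a missing native library.
    Error cause = new UnsatisfiedLinkError("no snappyjava in java.library.path");

    IllegalArgumentException masked = new IllegalArgumentException();       // root cause lost
    IllegalArgumentException chained = new IllegalArgumentException(cause); // root cause kept

    System.out.println(masked.getCause());   // prints: null
    System.out.println(chained.getCause());  // prints the UnsatisfiedLinkError with its message
  }
}
```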
--- core/src/main/scala/org/apache/spark/io/CompressionCodec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 607d5a321efca..9dc36704a676d 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -148,7 +148,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { try { Snappy.getNativeLibraryVersion } catch { - case e: Error => throw new IllegalArgumentException + case e: Error => throw new IllegalArgumentException(e) } override def compressedOutputStream(s: OutputStream): OutputStream = { From 8285e3b0d3dc0eff669eba993742dfe0401116f9 Mon Sep 17 00:00:00 2001 From: Nithin Asokan Date: Sat, 12 Sep 2015 09:50:49 +0100 Subject: [PATCH 266/802] [SPARK-10554] [CORE] Fix NPE with ShutdownHook https://issues.apache.org/jira/browse/SPARK-10554 Fixes NPE when ShutdownHook tries to cleanup temporary folders Author: Nithin Asokan Closes #8720 from nasokan/SPARK-10554. --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3f8d26e1d4cab..f7e84a2c2e14c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -164,7 +164,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon private def doStop(): Unit = { // Only perform cleanup if an external service is not serving our shuffle files. - if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) { + // Also blockManagerId could be null if block manager is not initialized properly. + if (!blockManager.externalShuffleServiceEnabled || + (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { From 22730ad54d681ad30e63fe910e8d89360853177d Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 12 Sep 2015 10:40:10 +0100 Subject: [PATCH 267/802] [SPARK-10547] [TEST] Streamline / improve style of Java API tests Fix a few Java API test style issues: unused generic types, exceptions, wrong assert argument order Author: Sean Owen Closes #8706 from srowen/SPARK-10547. 
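The kinds of cleanups this commit makes are easiest to see side by side. The snippet below is an illustrative sketch (names are made up, and it assumes JUnit on the classpath), not code from the patch: redundant type arguments collapse to the Java 7 diamond operator, unnecessary `throws` clauses are dropped, and JUnit's `assertEquals(expected, actual)` gets its arguments in the documented order so failure messages read correctly.

```
import java.util.ArrayList;
import java.util.List;
import org.junit.Assert;

public final class StyleSketch {
  // Before: public static List<String> build() throws Exception
  // After: nothing checked is thrown, so the throws clause goes away.
  public static List<String> build() {
    // Before: new ArrayList<String>()  ->  After: diamond operator.
    List<String> values = new ArrayList<>();
    values.add("a");
    return values;
  }

  public static void main(String[] args) {
    // assertEquals takes (expected, actual); swapping them still passes or fails
    // the same way, but a failure would report the two values backwards.
    Assert.assertEquals("a", build().get(0));
  }
}
```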
--- .../java/org/apache/spark/JavaAPISuite.java | 451 ++++++----- .../kafka/JavaDirectKafkaStreamSuite.java | 24 +- .../streaming/kafka/JavaKafkaRDDSuite.java | 17 +- .../streaming/kafka/JavaKafkaStreamSuite.java | 14 +- .../twitter/JavaTwitterStreamSuite.java | 4 +- .../java/org/apache/spark/Java8APISuite.java | 46 +- .../spark/sql/JavaApplySchemaSuite.java | 39 +- .../apache/spark/sql/JavaDataFrameSuite.java | 29 +- .../org/apache/spark/sql/JavaRowSuite.java | 15 +- .../org/apache/spark/sql/JavaUDFSuite.java | 9 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 10 +- .../spark/sql/hive/JavaDataFrameSuite.java | 8 +- .../hive/JavaMetastoreDataSourcesSuite.java | 12 +- .../apache/spark/streaming/JavaAPISuite.java | 752 +++++++++--------- .../spark/streaming/JavaReceiverAPISuite.java | 86 +- 15 files changed, 755 insertions(+), 761 deletions(-) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index ebd3d61ae7324..fd8f7f39b7cc8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -90,7 +90,7 @@ public void sparkContextUnion() { JavaRDD sUnion = sc.union(s1, s2); Assert.assertEquals(4, sUnion.count()); // List - List> list = new ArrayList>(); + List> list = new ArrayList<>(); list.add(s2); sUnion = sc.union(s1, list); Assert.assertEquals(4, sUnion.count()); @@ -103,9 +103,9 @@ public void sparkContextUnion() { Assert.assertEquals(4, dUnion.count()); // Union of JavaPairRDDs - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pUnion = sc.union(p1, p2); @@ -133,9 +133,9 @@ public void intersection() { JavaDoubleRDD dIntersection = d1.intersection(d2); Assert.assertEquals(2, dIntersection.count()); - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(1, 2)); - pairs.add(new Tuple2(3, 4)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(1, 2)); + pairs.add(new Tuple2<>(3, 4)); JavaPairRDD p1 = sc.parallelizePairs(pairs); JavaPairRDD p2 = sc.parallelizePairs(pairs); JavaPairRDD pIntersection = p1.intersection(p2); @@ -165,47 +165,49 @@ public void randomSplit() { @Test public void sortByKey() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaPairRDD rdd = sc.parallelizePairs(pairs); // Default comparator JavaPairRDD sortedRDD = rdd.sortByKey(); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new 
Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); } @SuppressWarnings("unchecked") @Test public void repartitionAndSortWithinPartitions() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 5)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(2, 6)); - pairs.add(new Tuple2(0, 8)); - pairs.add(new Tuple2(3, 8)); - pairs.add(new Tuple2(1, 3)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 5)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(2, 6)); + pairs.add(new Tuple2<>(0, 8)); + pairs.add(new Tuple2<>(3, 8)); + pairs.add(new Tuple2<>(1, 3)); JavaPairRDD rdd = sc.parallelizePairs(pairs); Partitioner partitioner = new Partitioner() { + @Override public int numPartitions() { return 2; } + @Override public int getPartition(Object key) { - return ((Integer)key).intValue() % 2; + return (Integer) key % 2; } }; @@ -214,10 +216,10 @@ public int getPartition(Object key) { Assert.assertTrue(repartitioned.partitioner().isPresent()); Assert.assertEquals(repartitioned.partitioner().get(), partitioner); List>> partitions = repartitioned.glom().collect(); - Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), - new Tuple2(0, 8), new Tuple2(2, 6))); - Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), - new Tuple2(3, 8), new Tuple2(3, 8))); + Assert.assertEquals(partitions.get(0), + Arrays.asList(new Tuple2<>(0, 5), new Tuple2<>(0, 8), new Tuple2<>(2, 6))); + Assert.assertEquals(partitions.get(1), + Arrays.asList(new Tuple2<>(1, 3), new Tuple2<>(3, 8), new Tuple2<>(3, 8))); } @Test @@ -228,35 +230,37 @@ public void emptyRDD() { @Test public void sortBy() { - List> pairs = new ArrayList>(); - pairs.add(new Tuple2(0, 4)); - pairs.add(new Tuple2(3, 2)); - pairs.add(new Tuple2(-1, 1)); + List> pairs = new ArrayList<>(); + pairs.add(new Tuple2<>(0, 4)); + pairs.add(new Tuple2<>(3, 2)); + pairs.add(new Tuple2<>(-1, 1)); JavaRDD> rdd = sc.parallelize(pairs); // compare on first value JavaRDD> sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._1(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value sortedRDD = rdd.sortBy(new Function, Integer>() { - public Integer call(Tuple2 t) throws Exception { + @Override + public Integer call(Tuple2 t) { return t._2(); } }, true, 2); - Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); + Assert.assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); - Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(1)); - Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(2)); + Assert.assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); + Assert.assertEquals(new Tuple2<>(0, 4), sortedPairs.get(2)); } @Test @@ -265,7 +269,7 @@ public void foreach() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) throws IOException { + public void call(String s) { accum.add(1); } 
}); @@ -278,7 +282,7 @@ public void foreachPartition() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override - public void call(Iterator iter) throws IOException { + public void call(Iterator iter) { while (iter.hasNext()) { iter.next(); accum.add(1); @@ -301,7 +305,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertEquals(4, new HashSet(indexes.collect()).size()); + Assert.assertEquals(4, new HashSet<>(indexes.collect()).size()); } @Test @@ -317,10 +321,10 @@ public void zipWithIndex() { @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") - )); + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") + )); Assert.assertEquals(2, categories.lookup("Oranges").size()); Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @@ -390,18 +394,17 @@ public String call(Tuple2 x) { @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -411,23 +414,22 @@ public void cogroup() { @Test public void cogroup3() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); @@ -439,27 +441,26 @@ public void cogroup3() { @Test public void cogroup4() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( - new Tuple2("Apples", "Fruit"), - new Tuple2("Oranges", "Fruit"), - new Tuple2("Oranges", "Citrus") + new Tuple2<>("Apples", "Fruit"), + new Tuple2<>("Oranges", "Fruit"), + 
new Tuple2<>("Oranges", "Citrus") )); JavaPairRDD prices = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 2), - new Tuple2("Apples", 3) + new Tuple2<>("Oranges", 2), + new Tuple2<>("Apples", 3) )); JavaPairRDD quantities = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", 21), - new Tuple2("Apples", 42) + new Tuple2<>("Oranges", 21), + new Tuple2<>("Apples", 42) )); JavaPairRDD countries = sc.parallelizePairs(Arrays.asList( - new Tuple2("Oranges", "BR"), - new Tuple2("Apples", "US") + new Tuple2<>("Oranges", "BR"), + new Tuple2<>("Apples", "US") )); JavaPairRDD, Iterable, Iterable, Iterable>> cogrouped = categories.cogroup(prices, quantities, countries); - Assert.assertEquals("[Fruit, Citrus]", - Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); Assert.assertEquals("[42]", Iterables.toString(cogrouped.lookup("Apples").get(0)._3())); Assert.assertEquals("[BR]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._4())); @@ -471,16 +472,16 @@ public void cogroup4() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -548,11 +549,11 @@ public Integer call(Integer a, Integer b) { public void aggregateByKey() { JavaPairRDD pairs = sc.parallelizePairs( Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(5, 1), - new Tuple2(5, 3)), 2); + new Tuple2<>(1, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(5, 1), + new Tuple2<>(5, 3)), 2); Map> sets = pairs.aggregateByKey(new HashSet(), new Function2, Integer, Set>() { @@ -570,20 +571,20 @@ public Set call(Set a, Set b) { } }).collectAsMap(); Assert.assertEquals(3, sets.size()); - Assert.assertEquals(new HashSet(Arrays.asList(1)), sets.get(1)); - Assert.assertEquals(new HashSet(Arrays.asList(2)), sets.get(3)); - Assert.assertEquals(new HashSet(Arrays.asList(1, 3)), sets.get(5)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); + Assert.assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); + Assert.assertEquals(new HashSet<>(Arrays.asList(1, 3)), sets.get(5)); } @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, @@ -602,11 +603,11 @@ public Integer call(Integer a, Integer b) { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey( @@ 
-690,7 +691,7 @@ public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaRDD stringRDD = sc.parallelize(Arrays.asList("Hello", "World")); JavaPairRDD cartesian = stringRDD.cartesian(doubleRDD); - Assert.assertEquals(new Tuple2("Hello", 1.0), cartesian.first()); + Assert.assertEquals(new Tuple2<>("Hello", 1.0), cartesian.first()); } @Test @@ -743,6 +744,7 @@ public void javaDoubleRDDHistoGram() { } private static class DoubleComparator implements Comparator, Serializable { + @Override public int compare(Double o1, Double o2) { return o1.compareTo(o2); } @@ -766,14 +768,14 @@ public void min() { public void naturalMax() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.max(); - Assert.assertTrue(4.0 == max); + Assert.assertEquals(4.0, max, 0.0); } @Test public void naturalMin() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double max = rdd.min(); - Assert.assertTrue(1.0 == max); + Assert.assertEquals(1.0, max, 0.0); } @Test @@ -809,7 +811,7 @@ public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); double sum = rdd.reduce(new Function2() { @Override - public Double call(Double v1, Double v2) throws Exception { + public Double call(Double v1, Double v2) { return v1 + v2; } }); @@ -844,7 +846,7 @@ public double call(Integer x) { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, x); + return new Tuple2<>(x, x); } }).cache(); pairs.collect(); @@ -870,26 +872,25 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMapToPair( + JavaPairRDD pairsRDD = rdd.flatMapToPair( new PairFlatMapFunction() { - @Override public Iterable> call(String s) { - List> pairs = new LinkedList>(); + List> pairs = new LinkedList<>(); for (String word : s.split(" ")) { - pairs.add(new Tuple2(word, word)); + pairs.add(new Tuple2<>(word, word)); } return pairs; } } ); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); + Assert.assertEquals(11, pairsRDD.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) { lengths.add((double) word.length()); } @@ -897,36 +898,36 @@ public Iterable call(String s) { } }); Assert.assertEquals(5.0, doubles.first(), 0.01); - Assert.assertEquals(11, pairs.count()); + Assert.assertEquals(11, pairsRDD.count()); } @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { - List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") - ); - JavaPairRDD pairRDD = sc.parallelizePairs(pairs); - - // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterable> call(Tuple2 item) { - return Collections.singletonList(item.swap()); - } + List> pairs = Arrays.asList( + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = pairRDD.flatMapToPair( + new PairFlatMapFunction, String, 
Integer>() { + @Override + public Iterable> call(Tuple2 item) { + return Collections.singletonList(item.swap()); + } }); - swapped.collect(); + swapped.collect(); - // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + // There was never a bug here, but it's worth testing: + pairRDD.mapToPair(new PairFunction, String, Integer>() { + @Override + public Tuple2 call(Tuple2 item) { + return item.swap(); + } + }).collect(); } @Test @@ -953,7 +954,7 @@ public void mapPartitionsWithIndex() { JavaRDD partitionSums = rdd.mapPartitionsWithIndex( new Function2, Iterator>() { @Override - public Iterator call(Integer index, Iterator iter) throws Exception { + public Iterator call(Integer index, Iterator iter) { int sum = 0; while (iter.hasNext()) { sum += iter.next(); @@ -972,8 +973,8 @@ public void repartition() { JavaRDD repartitioned1 = in1.repartition(4); List> result1 = repartitioned1.glom().collect(); Assert.assertEquals(4, result1.size()); - for (List l: result1) { - Assert.assertTrue(l.size() > 0); + for (List l : result1) { + Assert.assertFalse(l.isEmpty()); } // Growing number of partitions @@ -982,7 +983,7 @@ public void repartition() { List> result2 = repartitioned2.glom().collect(); Assert.assertEquals(2, result2.size()); for (List l: result2) { - Assert.assertTrue(l.size() > 0); + Assert.assertFalse(l.isEmpty()); } } @@ -994,9 +995,9 @@ public void persist() { Assert.assertEquals(20, doubleRDD.sum(), 0.1); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY()); @@ -1046,7 +1047,7 @@ public void wholeTextFiles() throws Exception { Files.write(content1, new File(tempDirName + "/part-00000")); Files.write(content2, new File(tempDirName + "/part-00001")); - Map container = new HashMap(); + Map container = new HashMap<>(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -1075,16 +1076,16 @@ public void textFilesCompressed() throws IOException { public void sequenceFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); @@ -1093,7 +1094,7 @@ public Tuple2 call(Tuple2 pair) { Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(pair._1().get(), pair._2().toString()); + return new Tuple2<>(pair._1().get(), pair._2().toString()); } }); Assert.assertEquals(pairs, readRDD.collect()); @@ -1110,7 +1111,7 @@ public void binaryFiles() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); 
+ ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName, 3); @@ -1131,14 +1132,14 @@ public void binaryFilesCaching() throws Exception { FileOutputStream fos1 = new FileOutputStream(file1); FileChannel channel1 = fos1.getChannel(); - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); readRDD.foreach(new VoidFunction>() { @Override - public void call(Tuple2 pair) throws Exception { + public void call(Tuple2 pair) { pair._2().toArray(); // force the file to read } }); @@ -1162,7 +1163,7 @@ public void binaryRecords() throws Exception { FileChannel channel1 = fos1.getChannel(); for (int i = 0; i < numOfCopies; i++) { - ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1); + ByteBuffer bbuf = ByteBuffer.wrap(content1); channel1.write(bbuf); } channel1.close(); @@ -1180,24 +1181,23 @@ public void binaryRecords() throws Exception { public void writeWithNewAPIHadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } - }).saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, - org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); + }).saveAsNewAPIHadoopFile( + outputDir, IntWritable.class, Text.class, + org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); - JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, - Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + JavaPairRDD output = sc.sequenceFile(outputDir, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1210,24 +1210,23 @@ public String call(Tuple2 x) { public void readWithNewAPIHadoopFile() throws IOException { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, IntWritable.class, - Text.class, new Job().getConfiguration()); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, new Job().getConfiguration()); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String 
call(Tuple2 x) { return x.toString(); @@ -1251,9 +1250,9 @@ public void objectFilesOfInts() { public void objectFilesOfComplexTypes() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.saveAsObjectFile(outputDir); @@ -1267,23 +1266,22 @@ public void objectFilesOfComplexTypes() { public void hadoopFile() { String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + SequenceFileInputFormat.class, IntWritable.class, Text.class); + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1296,16 +1294,16 @@ public String call(Tuple2 x) { public void hadoopFileCompressed() { String outputDir = new File(tempDir, "output_compressed").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { - return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); + return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); } }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, DefaultCodec.class); @@ -1313,8 +1311,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.hadoopFile(outputDir, SequenceFileInputFormat.class, IntWritable.class, Text.class); - Assert.assertEquals(pairs.toString(), output.map(new Function, - String>() { + Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { return x.toString(); @@ -1414,8 +1411,8 @@ public String call(Integer t) { return t.toString(); } }).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test @@ -1448,20 +1445,20 @@ public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); Function keyFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1 % 3; } }; Function createCombinerFunction = new Function() { @Override - public Integer call(Integer v1) throws Exception { + public Integer call(Integer v1) { return v1; } }; Function2 mergeValueFunction = new Function2() { @Override - public Integer call(Integer v1, 
Integer v2) throws Exception { + public Integer call(Integer v1, Integer v2) { return v1 + v2; } }; @@ -1496,21 +1493,21 @@ public void mapOnPairRDD() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2(in._2(), in._1()); - } - }); + @Override + public Tuple2 call(Tuple2 in) { + return new Tuple2<>(in._2(), in._1()); + } + }); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @@ -1523,7 +1520,7 @@ public void collectPartitions() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i, i % 2); + return new Tuple2<>(i, i % 2); } }); @@ -1534,23 +1531,23 @@ public Tuple2 call(Integer i) { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), + new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), + new Tuple2<>(4, 0)), parts2[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), parts2[1]); } @Test public void countApproxDistinct() { - List arrayData = new ArrayList(); + List arrayData = new ArrayList<>(); int size = 100; for (int i = 0; i < 100000; i++) { arrayData.add(i % size); @@ -1561,15 +1558,15 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { - List> arrayData = new ArrayList>(); + List> arrayData = new ArrayList<>(); for (int i = 10; i < 100; i++) { for (int j = 0; j < i; j++) { - arrayData.add(new Tuple2(i, j)); + arrayData.add(new Tuple2<>(i, j)); } } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); for (Tuple2 resItem : res) { double count = (double)resItem._1(); Long resCount = (Long)resItem._2(); @@ -1587,7 +1584,7 @@ public void collectAsMapWithIntArrayValues() { new PairFunction() { @Override public Tuple2 call(Integer x) { - return new Tuple2(x, new int[] { x }); + return new Tuple2<>(x, new int[]{x}); } }); pairRDD.collect(); // Works fine @@ -1598,7 +1595,7 @@ public Tuple2 call(Integer x) { @Test public void collectAsMapAndSerialize() throws Exception { JavaPairRDD rdd = - sc.parallelizePairs(Arrays.asList(new Tuple2("foo", 1))); + sc.parallelizePairs(Arrays.asList(new Tuple2<>("foo", 1))); Map map = rdd.collectAsMap(); ByteArrayOutputStream bytes = new ByteArrayOutputStream(); new ObjectOutputStream(bytes).writeObject(map); @@ -1615,7 +1612,7 @@ public void sampleByKey() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1623,12 +1620,12 @@ public Tuple2 call(Integer i) { fractions.put(1, 
1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); Map wrCounts = (Map) (Object) wr.countByKey(); - Assert.assertTrue(wrCounts.size() == 2); + Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); Map worCounts = (Map) (Object) wor.countByKey(); - Assert.assertTrue(worCounts.size() == 2); + Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); } @@ -1641,7 +1638,7 @@ public void sampleByKeyExact() { new PairFunction() { @Override public Tuple2 call(Integer i) { - return new Tuple2(i % 2, 1); + return new Tuple2<>(i % 2, 1); } }); Map fractions = Maps.newHashMap(); @@ -1649,25 +1646,25 @@ public Tuple2 call(Integer i) { fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); - Assert.assertTrue(wrExactCounts.size() == 2); + Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); - Assert.assertTrue(worExactCounts.size() == 2); + Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); } private static class SomeCustomClass implements Serializable { - public SomeCustomClass() { + SomeCustomClass() { // Intentionally left blank } } @Test public void collectUnderlyingScalaRDD() { - List data = new ArrayList(); + List data = new ArrayList<>(); for (int i = 0; i < 100; i++) { data.add(new SomeCustomClass()); } @@ -1679,7 +1676,7 @@ public void collectUnderlyingScalaRDD() { private static final class BuggyMapFunction implements Function { @Override - public T call(T x) throws Exception { + public T call(T x) { throw new IllegalStateException("Custom exception!"); } } @@ -1716,7 +1713,7 @@ public void foreachAsync() throws Exception { JavaFutureAction future = rdd.foreachAsync( new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) { // intentionally left blank. } } @@ -1745,7 +1742,7 @@ public void testAsyncActionCancellation() throws Exception { JavaRDD rdd = sc.parallelize(data, 1); JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { @Override - public void call(Integer integer) throws Exception { + public void call(Integer integer) throws InterruptedException { Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. 
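// Aside — a minimal sketch of how the async API exercised by these tests is normally
// driven (the names below are illustrative, not taken from the patch). JavaFutureAction
// extends java.util.concurrent.Future, so the usual Future methods apply:
//
//   JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
//   JavaFutureAction<Void> action = nums.foreachAsync(new VoidFunction<Integer>() {
//     @Override
//     public void call(Integer i) { /* per-element side effect */ }
//   });
//   action.cancel(true);                 // ask Spark to cancel the running job
//   boolean wasCancelled = action.isCancelled();
//   // or, to wait for completion instead of cancelling: action.get();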
} }); diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 9db07d0507fea..fbdfbf7e509b3 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -75,11 +75,11 @@ public void testKafkaStream() throws InterruptedException { String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); - HashSet sent = new HashSet(); + Set sent = new HashSet<>(); sent.addAll(Arrays.asList(topic1data)); sent.addAll(Arrays.asList(topic2data)); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); kafkaParams.put("auto.offset.reset", "smallest"); @@ -95,17 +95,17 @@ public void testKafkaStream() throws InterruptedException { // Make sure you can get offset ranges from the rdd new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd) { OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); offsetRanges.set(offsets); - Assert.assertEquals(offsets[0].topic(), topic1); + Assert.assertEquals(topic1, offsets[0].topic()); return rdd; } } ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -119,10 +119,10 @@ public String call(Tuple2 kv) throws Exception { StringDecoder.class, String.class, kafkaParams, - topicOffsetToMap(topic2, (long) 0), + topicOffsetToMap(topic2, 0L), new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -133,7 +133,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception unifiedStream.foreachRDD( new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public Void call(JavaRDD rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( @@ -155,14 +155,14 @@ public Void call(JavaRDD rdd) throws Exception { ssc.stop(); } - private HashSet topicToSet(String topic) { - HashSet topicSet = new HashSet(); + private static Set topicToSet(String topic) { + Set topicSet = new HashSet<>(); topicSet.add(topic); return topicSet; } - private HashMap topicOffsetToMap(String topic, Long offsetToStart) { - HashMap topicMap = new HashMap(); + private static Map topicOffsetToMap(String topic, Long offsetToStart) { + Map topicMap = new HashMap<>(); topicMap.put(new TopicAndPartition(topic, 0), offsetToStart); return topicMap; } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613ca..afcc6cfccd39a 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.HashMap; +import java.util.Map; import scala.Tuple2; @@ -66,10 +67,10 @@ public void testKafkaRDD() throws InterruptedException { String topic1 = "topic1"; String topic2 = 
"topic2"; - String[] topic1data = createTopicAndSendData(topic1); - String[] topic2data = createTopicAndSendData(topic2); + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); OffsetRange[] offsetRanges = { @@ -77,8 +78,8 @@ public void testKafkaRDD() throws InterruptedException { OffsetRange.create(topic2, 0, 0, 1) }; - HashMap emptyLeaders = new HashMap(); - HashMap leaders = new HashMap(); + Map emptyLeaders = new HashMap<>(); + Map leaders = new HashMap<>(); String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1])); leaders.put(new TopicAndPartition(topic1, 0), broker); @@ -95,7 +96,7 @@ public void testKafkaRDD() throws InterruptedException { ).map( new Function, String>() { @Override - public String call(Tuple2 kv) throws Exception { + public String call(Tuple2 kv) { return kv._2(); } } @@ -113,7 +114,7 @@ public String call(Tuple2 kv) throws Exception { emptyLeaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } @@ -131,7 +132,7 @@ public String call(MessageAndMetadata msgAndMd) throws Exception leaders, new Function, String>() { @Override - public String call(MessageAndMetadata msgAndMd) throws Exception { + public String call(MessageAndMetadata msgAndMd) { return msgAndMd.message(); } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index e4c659215b767..1e69de46cd35d 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -67,10 +67,10 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { String topic = "topic1"; - HashMap topics = new HashMap(); + Map topics = new HashMap<>(); topics.put(topic, 1); - HashMap sent = new HashMap(); + Map sent = new HashMap<>(); sent.put("a", 5); sent.put("b", 3); sent.put("c", 10); @@ -78,7 +78,7 @@ public void testKafkaStream() throws InterruptedException { kafkaTestUtils.createTopic(topic); kafkaTestUtils.sendMessages(topic, sent); - HashMap kafkaParams = new HashMap(); + Map kafkaParams = new HashMap<>(); kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress()); kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); @@ -97,7 +97,7 @@ public void testKafkaStream() throws InterruptedException { JavaDStream words = stream.map( new Function, String>() { @Override - public String call(Tuple2 tuple2) throws Exception { + public String call(Tuple2 tuple2) { return tuple2._2(); } } @@ -106,7 +106,7 @@ public String call(Tuple2 tuple2) throws Exception { words.countByValue().foreachRDD( new Function, Void>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public Void call(JavaPairRDD rdd) { List> ret = rdd.collect(); for (Tuple2 r : ret) { if (result.containsKey(r._1())) { @@ -130,8 +130,8 @@ public Void call(JavaPairRDD rdd) throws Exception { Thread.sleep(200); } Assert.assertEquals(sent.size(), result.size()); - for (String k : sent.keySet()) { - 
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); } } } diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java index e46b4e5c7531d..26ec8af455bcf 100644 --- a/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java +++ b/external/twitter/src/test/java/org/apache/spark/streaming/twitter/JavaTwitterStreamSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.streaming.twitter; -import java.util.Arrays; - import org.junit.Test; import twitter4j.Status; import twitter4j.auth.Authorization; @@ -30,7 +28,7 @@ public class JavaTwitterStreamSuite extends LocalJavaStreamingContext { @Test public void testTwitterStream() { - String[] filters = (String[])Arrays.asList("filter1", "filter2").toArray(); + String[] filters = { "filter1", "filter2" }; Authorization auth = NullAuthorization.getInstance(); // tests the API, does not actually test data receiving diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 729bc0459ce52..14975265ab2ce 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -77,7 +77,7 @@ public void call(String s) { public void foreach() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach((x) -> foreachCalls++); + rdd.foreach(x -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } @@ -180,7 +180,7 @@ public void map() { JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); - JavaRDD strings = rdd.map(x -> x.toString()).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -195,7 +195,9 @@ public void flatMap() { JavaPairRDD pairs = rdd.flatMapToPair(s -> { List> pairs2 = new LinkedList<>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); + for (String word : s.split(" ")) { + pairs2.add(new Tuple2<>(word, word)); + } return pairs2; }); @@ -204,11 +206,12 @@ public void flatMap() { JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { List lengths = new LinkedList<>(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; }); - Double x = doubles.first(); Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @@ -228,7 +231,7 @@ public void mapsFromPairsToPairs() { swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(item -> item.swap()).collect(); + pairRDD.map(Tuple2::swap).collect(); } @Test @@ -282,11 +285,11 @@ public void zipPartitions() { FlatMapFunction2, Iterator, Integer> sizesFn = (Iterator i, Iterator s) -> { int sizeI = 0; - int sizeS = 0; while (i.hasNext()) { sizeI += 1; i.next(); } + int sizeS = 0; while (s.hasNext()) { sizeS += 1; s.next(); @@ -301,30 +304,31 @@ public void zipPartitions() { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(x -> intAccum.add(x)); + Accumulator intAccum = sc.intAccumulator(10); + 
rdd.foreach(intAccum::add); Assert.assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(x -> doubleAccum.add((double) x)); Assert.assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } - + @Override public Float addAccumulator(Float r, Float t) { return r + t; } - + @Override public Float zero(Float initialValue) { return 0.0f; } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -336,7 +340,7 @@ public Float zero(Float initialValue) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(x -> x.toString()).collect(); + List> s = rdd.keyBy(Object::toString).collect(); Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -349,7 +353,7 @@ public void mapOnPairRDD() { JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( - new Tuple2(1, 1), + new Tuple2<>(1, 1), new Tuple2<>(0, 2), new Tuple2<>(1, 3), new Tuple2<>(0, 4)), rdd3.collect()); @@ -361,7 +365,7 @@ public void collectPartitions() { JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); - List[] parts = rdd1.collectPartitions(new int[]{0}); + List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[]{1, 2}); @@ -371,19 +375,19 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); - parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + List>[] parts2 = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), - parts[1]); + parts2[1]); } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index bf693c7c393f6..7b50aad4ad498 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.io.Serializable; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -83,7 +84,7 @@ public void setAge(int age) { @Test public void applySchema() { - List personList = new ArrayList(2); + 
List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -95,12 +96,13 @@ public void applySchema() { JavaRDD rowRDD = javaCtx.parallelize(personList).map( new Function() { + @Override public Row call(Person person) throws Exception { return RowFactory.create(person.getName(), person.getAge()); } }); - List fields = new ArrayList(2); + List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); @@ -118,7 +120,7 @@ public Row call(Person person) throws Exception { @Test public void dataFrameRDDOperations() { - List personList = new ArrayList(2); + List personList = new ArrayList<>(2); Person person1 = new Person(); person1.setName("Michael"); person1.setAge(29); @@ -129,27 +131,28 @@ public void dataFrameRDDOperations() { personList.add(person2); JavaRDD rowRDD = javaCtx.parallelize(personList).map( - new Function() { - public Row call(Person person) throws Exception { - return RowFactory.create(person.getName(), person.getAge()); - } - }); - - List fields = new ArrayList(2); - fields.add(DataTypes.createStructField("name", DataTypes.StringType, false)); + new Function() { + @Override + public Row call(Person person) { + return RowFactory.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList<>(2); + fields.add(DataTypes.createStructField("", DataTypes.StringType, false)); fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { - + @Override public String call(Row row) { - return row.getString(0) + "_" + row.get(1).toString(); + return row.getString(0) + "_" + row.get(1); } }).collect(); - List expected = new ArrayList(2); + List expected = new ArrayList<>(2); expected.add("Michael_29"); expected.add("Yin_28"); @@ -165,7 +168,7 @@ public void applySchemaToJSON() { "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); - List fields = new ArrayList(7); + List fields = new ArrayList<>(7); fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(20, 0), true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); @@ -175,10 +178,10 @@ public void applySchemaToJSON() { fields.add(DataTypes.createStructField("null", DataTypes.StringType, true)); fields.add(DataTypes.createStructField("string", DataTypes.StringType, true)); StructType expectedSchema = DataTypes.createStructType(fields); - List expectedResult = new ArrayList(2); + List expectedResult = new ArrayList<>(2); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758070"), + new BigDecimal("92233720368547758070"), true, 1.7976931348623157E308, 10, @@ -187,7 +190,7 @@ public void applySchemaToJSON() { "this is a simple string.")); expectedResult.add( RowFactory.create( - new java.math.BigDecimal("92233720368547758069"), + new BigDecimal("92233720368547758069"), false, 1.7976931348623157E305, 11, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 4867cebf5328c..d981ce947f435 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -61,7 +61,7 @@ public void tearDown() { @Test public void testExecution() { DataFrame df = context.table("testData").filter("key = 1"); - Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } /** @@ -119,7 +119,7 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; - private Integer[] b = new Integer[]{0, 1}; + private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); @@ -161,7 +161,7 @@ public void testCreateDataFrameFromJavaBeans() { schema.apply("d")); Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); - // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, + // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. Seq result = first.getAs(1); Assert.assertEquals(bean.getB().length, result.length()); @@ -180,7 +180,8 @@ public void testCreateDataFrameFromJavaBeans() { } } - private static Comparator CrosstabRowComparator = new Comparator() { + private static final Comparator crosstabRowComparator = new Comparator() { + @Override public int compare(Row row1, Row row2) { String item1 = row1.getString(0); String item2 = row2.getString(0); @@ -193,16 +194,16 @@ public void testCrosstab() { DataFrame df = context.table("testData2"); DataFrame crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); - Assert.assertEquals(columnNames[0], "a_b"); - Assert.assertEquals(columnNames[1], "1"); - Assert.assertEquals(columnNames[2], "2"); + Assert.assertEquals("a_b", columnNames[0]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); Row[] rows = crosstab.collect(); - Arrays.sort(rows, CrosstabRowComparator); + Arrays.sort(rows, crosstabRowComparator); Integer count = 1; for (Row row : rows) { Assert.assertEquals(row.get(0).toString(), count.toString()); - Assert.assertEquals(row.getLong(1), 1L); - Assert.assertEquals(row.getLong(2), 1L); + Assert.assertEquals(1L, row.getLong(1)); + Assert.assertEquals(1L, row.getLong(2)); count++; } } @@ -210,7 +211,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { DataFrame df = context.table("testData2"); - String[] cols = new String[]{"a"}; + String[] cols = {"a"}; DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } @@ -219,14 +220,14 @@ public void testFrequentItems() { public void testCorrelation() { DataFrame df = context.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); - Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6); + Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { DataFrame df = context.table("testData2"); Double result = df.stat().cov("a", "b"); - Assert.assertTrue(Math.abs(result) < 1e-6); + Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test @@ -234,7 +235,7 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 
2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; Assert.assertArrayEquals(expected, actual); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index 4ce1d1dddb26a..3ab4db2a035d3 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -18,6 +18,7 @@ package test.org.apache.spark.sql; import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; import java.sql.Date; import java.sql.Timestamp; import java.util.Arrays; @@ -52,12 +53,12 @@ public void setUp() { shortValue = (short)32767; intValue = 2147483647; longValue = 9223372036854775807L; - floatValue = (float)3.4028235E38; + floatValue = 3.4028235E38f; doubleValue = 1.7976931348623157E308; decimalValue = new BigDecimal("1.7976931348623157E328"); booleanValue = true; stringValue = "this is a string"; - binaryValue = stringValue.getBytes(); + binaryValue = stringValue.getBytes(StandardCharsets.UTF_8); dateValue = Date.valueOf("2014-06-30"); timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); } @@ -123,8 +124,8 @@ public void constructSimpleRow() { Assert.assertEquals(binaryValue, simpleRow.get(16)); Assert.assertEquals(dateValue, simpleRow.get(17)); Assert.assertEquals(timestampValue, simpleRow.get(18)); - Assert.assertEquals(true, simpleRow.isNullAt(19)); - Assert.assertEquals(null, simpleRow.get(19)); + Assert.assertTrue(simpleRow.isNullAt(19)); + Assert.assertNull(simpleRow.get(19)); } @Test @@ -134,7 +135,7 @@ public void constructComplexRow() { stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); // Simple map - Map simpleMap = new HashMap(); + Map simpleMap = new HashMap<>(); simpleMap.put(stringValue + " (1)", longValue); simpleMap.put(stringValue + " (2)", longValue - 1); simpleMap.put(stringValue + " (3)", longValue - 2); @@ -149,7 +150,7 @@ public void constructComplexRow() { List arrayOfRows = Arrays.asList(simpleStruct); // Complex map - Map, Row> complexMap = new HashMap, Row>(); + Map, Row> complexMap = new HashMap<>(); complexMap.put(arrayOfRows, simpleStruct); // Complex struct @@ -167,7 +168,7 @@ public void constructComplexRow() { Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); Assert.assertEquals(arrayOfRows, complexStruct.get(4)); Assert.assertEquals(complexMap, complexStruct.get(5)); - Assert.assertEquals(null, complexStruct.get(6)); + Assert.assertNull(complexStruct.get(6)); // A very complex row Row complexRow = RowFactory.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index bb02b58cca9be..4a78dca7fea66 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -20,6 +20,7 @@ import java.io.Serializable; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -61,13 +62,13 @@ public void udf1Test() { sqlContext.udf().register("stringLengthTest", new UDF1() { @Override - public Integer 
call(String str) throws Exception { + public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); - assert(result.getInt(0) == 4); + Assert.assertEquals(4, result.getInt(0)); } @SuppressWarnings("unchecked") @@ -81,12 +82,12 @@ public void udf2Test() { sqlContext.udf().register("stringLengthTest", new UDF2() { @Override - public Integer call(String str1, String str2) throws Exception { + public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); - assert(result.getInt(0) == 9); + Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 6f9e7f68dc39c..9e241f20987c0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -44,7 +44,7 @@ public class JavaSaveLoadSuite { File path; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -64,7 +64,7 @@ public void setUp() throws IOException { path.delete(); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -82,7 +82,7 @@ public void tearDown() { @Test public void saveAndLoad() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); @@ -91,11 +91,11 @@ public void saveAndLoad() { @Test public void saveAndLoadWithSchema() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index 019d8a30266e2..b4bf9eef8fca5 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -40,7 +40,7 @@ public class JavaDataFrameSuite { DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -52,7 +52,7 @@ public void setUp() throws IOException { hc = TestHive$.MODULE$; sc = new JavaSparkContext(hc.sparkContext()); - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; 
i < 10; i++) { jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); } @@ -71,7 +71,7 @@ public void tearDown() throws IOException { @Test public void saveTableAndQueryIt() { checkAnswer( - df.select(functions.avg("key").over( + df.select(avg("key").over( Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), hc.sql("SELECT avg(key) " + "OVER (PARTITION BY value " + @@ -95,7 +95,7 @@ public void testUDAF() { registeredUDAF.apply(col("value")), callUDF("mydoublesum", col("value"))); - List expectedResult = new ArrayList(); + List expectedResult = new ArrayList<>(); expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0)); checkAnswer( aggregatedDF, diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 4192155975c47..c8d272794d10b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -53,7 +53,7 @@ public class JavaMetastoreDataSourcesSuite { FileSystem fs; DataFrame df; - private void checkAnswer(DataFrame actual, List expected) { + private static void checkAnswer(DataFrame actual, List expected) { String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); if (errorMessage != null) { Assert.fail(errorMessage); @@ -77,7 +77,7 @@ public void setUp() throws IOException { fs.delete(hiveManagedPath, true); } - List jsonObjects = new ArrayList(10); + List jsonObjects = new ArrayList<>(10); for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } @@ -97,7 +97,7 @@ public void tearDown() throws IOException { @Test public void saveExternalTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -120,7 +120,7 @@ public void saveExternalTableAndQueryIt() { @Test public void saveExternalTableWithSchemaAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); options.put("path", path.toString()); df.write() .format("org.apache.spark.sql.json") @@ -132,7 +132,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { sqlContext.sql("SELECT * FROM javaSavedTable"), df.collectAsList()); - List fields = new ArrayList(); + List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); DataFrame loadedDF = @@ -148,7 +148,7 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { - Map options = new HashMap(); + Map options = new HashMap<>(); df.write() .format("org.apache.spark.sql.json") .mode(SaveMode.Append) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e0718f73aa13f..c5217149224e4 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -18,24 +18,22 @@ package org.apache.spark.streaming; import java.io.*; -import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; +import scala.Tuple2; + +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import 
org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import scala.Tuple2; - import org.junit.Assert; -import static org.junit.Assert.*; import org.junit.Test; import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -54,14 +52,14 @@ // see http://stackoverflow.com/questions/758570/. public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { - public void equalIterator(Iterator a, Iterator b) { + public static void equalIterator(Iterator a, Iterator b) { while (a.hasNext() && b.hasNext()) { Assert.assertEquals(a.next(), b.next()); } Assert.assertEquals(a.hasNext(), b.hasNext()); } - public void equalIterable(Iterable a, Iterable b) { + public static void equalIterable(Iterable a, Iterable b) { equalIterator(a.iterator(), b.iterator()); } @@ -74,14 +72,14 @@ public void testInitialization() { @Test public void testContextState() { List> inputData = Arrays.asList(Arrays.asList(1, 2, 3, 4)); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaTestUtils.attachTestOutputStream(stream); - Assert.assertTrue(ssc.getState() == StreamingContextState.INITIALIZED); + Assert.assertEquals(StreamingContextState.INITIALIZED, ssc.getState()); ssc.start(); - Assert.assertTrue(ssc.getState() == StreamingContextState.ACTIVE); + Assert.assertEquals(StreamingContextState.ACTIVE, ssc.getState()); ssc.stop(); - Assert.assertTrue(ssc.getState() == StreamingContextState.STOPPED); + Assert.assertEquals(StreamingContextState.STOPPED, ssc.getState()); } @SuppressWarnings("unchecked") @@ -118,7 +116,7 @@ public void testMap() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -180,7 +178,7 @@ public void testWindowWithSlideDuration() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -189,7 +187,7 @@ public void testFilter() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream filtered = stream.filter(new Function() { @Override - public Boolean call(String s) throws Exception { + public Boolean call(String s) { return s.contains("a"); } }); @@ -243,11 +241,11 @@ public void testRepartitionFewerPartitions() { public void testGlom() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( Arrays.asList(Arrays.asList("giants", "dodgers")), - Arrays.asList(Arrays.asList("yankees", "red socks"))); + Arrays.asList(Arrays.asList("yankees", "red sox"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream> glommed = stream.glom(); @@ -262,22 +260,22 @@ public void testGlom() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> 
expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions( new FlatMapFunction, String>() { @Override public Iterable call(Iterator in) { - String out = ""; + StringBuilder out = new StringBuilder(); while (in.hasNext()) { - out = out + in.next().toUpperCase(); + out.append(in.next().toUpperCase(Locale.ENGLISH)); } - return Lists.newArrayList(out); + return Arrays.asList(out.toString()); } }); JavaTestUtils.attachTestOutputStream(mapped); @@ -286,16 +284,16 @@ public Iterable call(Iterator in) { Assert.assertEquals(expected, result); } - private class IntegerSum implements Function2 { + private static class IntegerSum implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 + i2; } } - private class IntegerDifference implements Function2 { + private static class IntegerDifference implements Function2 { @Override - public Integer call(Integer i1, Integer i2) throws Exception { + public Integer call(Integer i1, Integer i2) { return i1 - i2; } } @@ -347,13 +345,13 @@ private void testReduceByWindow(boolean withInverse) { Arrays.asList(24)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream reducedWindowed = null; + JavaDStream reducedWindowed; if (withInverse) { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new IntegerDifference(), new Duration(2000), new Duration(1000)); + new IntegerDifference(), new Duration(2000), new Duration(1000)); } else { reducedWindowed = stream.reduceByWindow(new IntegerSum(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); } JavaTestUtils.attachTestOutputStream(reducedWindowed); List> result = JavaTestUtils.runStreams(ssc, 4, 4); @@ -378,11 +376,11 @@ public void testQueueStream() { Arrays.asList(7,8,9)); JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc()); - JavaRDD rdd1 = ssc.sparkContext().parallelize(Arrays.asList(1, 2, 3)); - JavaRDD rdd2 = ssc.sparkContext().parallelize(Arrays.asList(4, 5, 6)); - JavaRDD rdd3 = ssc.sparkContext().parallelize(Arrays.asList(7,8,9)); + JavaRDD rdd1 = jsc.parallelize(Arrays.asList(1, 2, 3)); + JavaRDD rdd2 = jsc.parallelize(Arrays.asList(4, 5, 6)); + JavaRDD rdd3 = jsc.parallelize(Arrays.asList(7,8,9)); - LinkedList> rdds = Lists.newLinkedList(); + Queue> rdds = new LinkedList<>(); rdds.add(rdd1); rdds.add(rdd2); rdds.add(rdd3); @@ -410,10 +408,10 @@ public void testTransform() { JavaDStream transformed = stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD call(JavaRDD in) { return in.map(new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i + 2; } }); @@ -435,70 +433,70 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); - JavaDStream transformed1 = stream.transform( + stream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaRDD in) throws Exception { + public JavaRDD 
call(JavaRDD in) { return null; } } ); - JavaDStream transformed2 = stream.transform( + stream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaRDD in, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream.transformToPair( + stream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in) throws Exception { + @Override public JavaPairRDD call(JavaRDD in) { return null; } } ); - JavaPairDStream transformed4 = stream.transformToPair( + stream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaRDD in, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream.transform( + pairStream.transform( new Function, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in) throws Exception { + @Override public JavaRDD call(JavaPairRDD in) { return null; } } ); - JavaDStream pairTransformed2 = pairStream.transform( + pairStream.transform( new Function2, Time, JavaRDD>() { - @Override public JavaRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaRDD call(JavaPairRDD in, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream.transformToPair( + pairStream.transformToPair( new Function, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in) { return null; } } ); - JavaPairDStream pairTransformed4 = pairStream.transformToPair( + pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { - @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { + @Override public JavaPairRDD call(JavaPairRDD in, Time time) { return null; } } @@ -511,32 +509,32 @@ public JavaRDD call(JavaRDD in) throws Exception { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -552,14 +550,12 @@ public void testTransformWith() { JavaPairRDD, JavaPairRDD, 
Time, - JavaPairRDD> - >() { + JavaPairRDD>>() { @Override public JavaPairRDD> call( JavaPairRDD rdd1, JavaPairRDD rdd2, - Time time - ) throws Exception { + Time time) { return rdd1.join(rdd2); } } @@ -567,9 +563,9 @@ public JavaPairRDD> call( JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = new ArrayList<>(); for (List>> res: result) { - unorderedResult.add(Sets.newHashSet(res)); + unorderedResult.add(Sets.newHashSet(res)); } Assert.assertEquals(expected, unorderedResult); @@ -587,89 +583,89 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); - JavaDStream transformed1 = stream1.transformWith( + stream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream transformed2 = stream1.transformWith( + stream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed3 = stream1.transformWithToPair( + stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaPairDStream transformed4 = stream1.transformWithToPair( + stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed1 = pairStream1.transformWith( + pairStream1.transformWith( stream2, new Function3, JavaRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { return null; } } ); - JavaDStream pairTransformed2_ = pairStream1.transformWith( + pairStream1.transformWith( pairStream1, new Function3, JavaPairRDD, Time, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } ); - JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaRDD rdd2, Time time) { 
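// These transform/transformWith overload checks only verify that the Java signatures
// compile and that the closures type-check; the streams are never started, so each
// call body simply returns null rather than a real RDD.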
return null; } } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( + pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) throws Exception { + public JavaPairRDD call(JavaPairRDD rdd1, JavaPairRDD rdd2, Time time) { return null; } } @@ -690,13 +686,13 @@ public void testStreamingContextTransform(){ ); List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -707,7 +703,7 @@ public void testStreamingContextTransform(){ List> listOfDStreams1 = Arrays.>asList(stream1, stream2); // This is just to test whether this transform to JavaStream compiles - JavaDStream transformed1 = ssc.transform( + ssc.transform( listOfDStreams1, new Function2>, Time, JavaRDD>() { @Override @@ -733,8 +729,8 @@ public JavaPairRDD> call(List> listO JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = new PairFunction() { @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i); + public Tuple2 call(Integer i) { + return new Tuple2<>(i, i); } }; return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); @@ -763,7 +759,7 @@ public void testFlatMap() { JavaDStream flatMapped = stream.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { - return Lists.newArrayList(x.split("(?!^)")); + return Arrays.asList(x.split("(?!^)")); } }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -782,39 +778,39 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair( new PairFlatMapFunction() { @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); + public Iterable> call(String in) { + List> out = new ArrayList<>(); for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + out.add(new 
Tuple2<>(in.length(), letter)); } return out; } @@ -859,13 +855,13 @@ public void testUnion() { */ public static void assertOrderInvariantEquals( List> expected, List> actual) { - List> expectedSets = new ArrayList>(); + List> expectedSets = new ArrayList<>(); for (List list: expected) { - expectedSets.add(Collections.unmodifiableSet(new HashSet(list))); + expectedSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } - List> actualSets = new ArrayList>(); + List> actualSets = new ArrayList<>(); for (List list: actual) { - actualSets.add(Collections.unmodifiableSet(new HashSet(list))); + actualSets.add(Collections.unmodifiableSet(new HashSet<>(list))); } Assert.assertEquals(expectedSets, actualSets); } @@ -877,25 +873,25 @@ public static void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override - public Tuple2 call(String in) throws Exception { - return new Tuple2(in, in.length()); + public Tuple2 call(String in) { + return new Tuple2<>(in, in.length()); } }); JavaPairDStream filtered = pairStream.filter( new Function, Boolean>() { @Override - public Boolean call(Tuple2 in) throws Exception { + public Boolean call(Tuple2 in) { return in._1().contains("a"); } }); @@ -906,28 +902,28 @@ public Boolean call(Tuple2 in) throws Exception { } @SuppressWarnings("unchecked") - private List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + private final List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); @SuppressWarnings("unchecked") - private List>> stringIntKVStream = Arrays.asList( + private final List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @SuppressWarnings("unchecked") @Test @@ -936,22 +932,22 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new 
Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return in.swap(); } }); @@ -969,23 +965,23 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override - public Iterable> call(Iterator> in) throws Exception { - LinkedList> out = new LinkedList>(); + public Iterable> call(Iterator> in) { + List> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -1014,7 +1010,7 @@ public void testPairMap2() { // Maps pair -> single JavaDStream reversed = pairStream.map( new Function, Integer>() { @Override - public Integer call(Tuple2 in) throws Exception { + public Integer call(Tuple2 in) { return in._2(); } }); @@ -1030,23 +1026,23 @@ public Integer call(Tuple2 in) throws Exception { public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -1054,10 +1050,10 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override - public Iterable> call(Tuple2 in) throws Exception { - List> out = new LinkedList>(); + public Iterable> call(Tuple2 in) { + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - 
out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; } @@ -1075,11 +1071,11 @@ public void testPairGroupByKey() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList("dodgers", "giants")), - new Tuple2>("new york", Arrays.asList("yankees", "mets"))), + new Tuple2<>("california", Arrays.asList("dodgers", "giants")), + new Tuple2<>("new york", Arrays.asList("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", Arrays.asList("sharks", "ducks")), - new Tuple2>("new york", Arrays.asList("rangers", "islanders")))); + new Tuple2<>("california", Arrays.asList("sharks", "ducks")), + new Tuple2<>("new york", Arrays.asList("rangers", "islanders")))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1111,11 +1107,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1136,20 +1132,20 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream combined = pairStream.combineByKey( + JavaPairDStream combined = pairStream.combineByKey( new Function() { @Override - public Integer call(Integer i) throws Exception { + public Integer call(Integer i) { return i; } }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2)); @@ -1170,13 +1166,13 @@ public void testCountByValue() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("moon", 1L)), Arrays.asList( - new Tuple2("hello", 1L))); + new Tuple2<>("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1193,16 +1189,16 @@ public void testGroupByKeyAndWindow() { List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3)), - new Tuple2>("new york", Arrays.asList(1, 4)) + new Tuple2<>("california", Arrays.asList(1, 3)), + new Tuple2<>("new york", Arrays.asList(1, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(1, 3, 5, 5)), - new Tuple2>("new york", Arrays.asList(1, 1, 3, 4)) + new Tuple2<>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 1, 3, 4)) ), Arrays.asList( - new Tuple2>("california", Arrays.asList(5, 5)), - new Tuple2>("new york", Arrays.asList(1, 3)) + new Tuple2<>("california", Arrays.asList(5, 5)), + new Tuple2<>("new york", Arrays.asList(1, 3)) ) ); @@ -1220,16 +1216,16 @@ public void testGroupByKeyAndWindow() 
{ } } - private HashSet>> convert(List>> listOfTuples) { - List>> newListOfTuples = new ArrayList>>(); + private static Set>> convert(List>> listOfTuples) { + List>> newListOfTuples = new ArrayList<>(); for (Tuple2> tuple: listOfTuples) { newListOfTuples.add(convert(tuple)); } - return new HashSet>>(newListOfTuples); + return new HashSet<>(newListOfTuples); } - private Tuple2> convert(Tuple2> tuple) { - return new Tuple2>(tuple._1(), new HashSet(tuple._2())); + private static Tuple2> convert(Tuple2> tuple) { + return new Tuple2<>(tuple._1(), new HashSet<>(tuple._2())); } @SuppressWarnings("unchecked") @@ -1238,12 +1234,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1262,12 +1258,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1278,10 +1274,10 @@ public void testUpdateStateByKey() { public Optional call(List values, Optional state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1298,19 +1294,19 @@ public void testUpdateStateByKeyWithInitial() { List>> inputData = stringIntKVStream; List> initial = Arrays.asList ( - new Tuple2 ("california", 1), - new Tuple2 ("new york", 2)); + new Tuple2<>("california", 1), + new Tuple2<>("new york", 2)); JavaRDD> tmpRDD = ssc.sparkContext().parallelize(initial); JavaPairRDD initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 5), - new Tuple2("new york", 7)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11)), - Arrays.asList(new Tuple2("california", 15), - new Tuple2("new york", 11))); + Arrays.asList(new Tuple2<>("california", 5), + new Tuple2<>("new york", 7)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11)), + Arrays.asList(new Tuple2<>("california", 15), + new Tuple2<>("new york", 11))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); @@ -1321,10 +1317,10 @@ public void testUpdateStateByKeyWithInitial() { public Optional call(List values, Optional 
state) { int out = 0; if (state.isPresent()) { - out = out + state.get(); + out += state.get(); } for (Integer v : values) { - out = out + v; + out += v; } return Optional.of(out); } @@ -1341,19 +1337,19 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1370,15 +1366,15 @@ public void testCountByValueAndWindow() { List>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), + new Tuple2<>("hello", 1L), + new Tuple2<>("world", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("world", 1L), - new Tuple2("moon", 1L)), + new Tuple2<>("hello", 2L), + new Tuple2<>("world", 1L), + new Tuple2<>("moon", 1L)), Sets.newHashSet( - new Tuple2("hello", 2L), - new Tuple2("moon", 1L))); + new Tuple2<>("hello", 2L), + new Tuple2<>("moon", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1386,7 +1382,7 @@ public void testCountByValueAndWindow() { stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); - List>> unorderedResult = Lists.newArrayList(); + List>> unorderedResult = new ArrayList<>(); for (List> res: result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -1399,27 +1395,27 @@ public void testCountByValueAndWindow() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1428,7 +1424,7 @@ public void testPairTransform() { JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override - public JavaPairRDD call(JavaPairRDD in) throws Exception { + public JavaPairRDD call(JavaPairRDD in) { return in.sortByKey(); } }); @@ 
-1444,15 +1440,15 @@ public JavaPairRDD call(JavaPairRDD in) thro public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3,1,4,2), @@ -1465,11 +1461,11 @@ public void testPairToNormalRDDTransform() { JavaDStream firstParts = pairStream.transform( new Function, JavaRDD>() { @Override - public JavaRDD call(JavaPairRDD in) throws Exception { + public JavaRDD call(JavaPairRDD in) { return in.map(new Function, Integer>() { @Override - public Integer call(Tuple2 in) { - return in._1(); + public Integer call(Tuple2 in2) { + return in2._1(); } }); } @@ -1487,14 +1483,14 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1502,8 +1498,8 @@ public void testMapValues() { JavaPairDStream mapped = pairStream.mapValues(new Function() { @Override - public String call(String s) throws Exception { - return s.toUpperCase(); + public String call(String s) { + return s.toUpperCase(Locale.ENGLISH); } }); @@ -1519,22 +1515,22 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", 
"islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -1545,7 +1541,7 @@ public void testFlatMapValues() { new Function>() { @Override public Iterable call(String in) { - List out = new ArrayList(); + List out = new ArrayList<>(); out.add(in + "1"); out.add(in + "2"); return out; @@ -1562,29 +1558,29 @@ public Iterable call(String in) { @Test public void testCoGroup() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List, List>>>> expected = Arrays.asList( Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("dodgers"), Arrays.asList("giants"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("yankees"), Arrays.asList("mets")))), + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("dodgers"), Arrays.asList("giants"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("yankees"), Arrays.asList("mets")))), Arrays.asList( - new Tuple2, List>>("california", - new Tuple2, List>(Arrays.asList("sharks"), Arrays.asList("ducks"))), - new Tuple2, List>>("new york", - new Tuple2, List>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); + new Tuple2<>("california", + new Tuple2<>(Arrays.asList("sharks"), Arrays.asList("ducks"))), + new Tuple2<>("new york", + new Tuple2<>(Arrays.asList("rangers"), Arrays.asList("islanders"))))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1620,29 +1616,29 @@ public void testCoGroup() { @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); List>>> expected = Arrays.asList( Arrays.asList( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", "mets"))), Arrays.asList( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new 
york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( @@ -1664,13 +1660,13 @@ public void testJoin() { @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), - Arrays.asList(new Tuple2("california", "sharks") )); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), + Arrays.asList(new Tuple2<>("california", "sharks") )); List>> stringStringKVStream2 = Arrays.asList( - Arrays.asList(new Tuple2("california", "giants") ), - Arrays.asList(new Tuple2("new york", "islanders") ) + Arrays.asList(new Tuple2<>("california", "giants") ), + Arrays.asList(new Tuple2<>("new york", "islanders") ) ); @@ -1713,7 +1709,7 @@ public void testCheckpointMasterRecovery() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1752,6 +1748,7 @@ public void testContextGetOrCreate() throws InterruptedException { // (used to detect the new context) final AtomicBoolean newContextCreated = new AtomicBoolean(false); Function0 creatingFunc = new Function0() { + @Override public JavaStreamingContext call() { newContextCreated.set(true); return new JavaStreamingContext(conf, Seconds.apply(1)); @@ -1765,20 +1762,20 @@ public JavaStreamingContext call() { newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); + new Configuration(), true); Assert.assertTrue("new context not created", newContextCreated.get()); ssc.stop(); newContextCreated.set(false); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); newContextCreated.set(false); JavaSparkContext sc = new JavaSparkContext(conf); ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); + new Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } @@ -1800,7 +1797,7 @@ public void testCheckpointofIndividualStream() throws InterruptedException { JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function() { @Override - public Integer call(String s) throws Exception { + public Integer call(String s) { return s.length(); } }); @@ -1818,29 +1815,26 @@ public Integer call(String s) throws Exception { // InputStream functionality is deferred to the existing Scala tests. 
@Test public void testSocketTextStream() { - JavaReceiverInputDStream test = ssc.socketTextStream("localhost", 12345); + ssc.socketTextStream("localhost", 12345); } @Test public void testSocketString() { - - class Converter implements Function> { - public Iterable call(InputStream in) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(in)); - List out = new ArrayList(); - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - return out; - } - } - - JavaDStream test = ssc.socketStream( + ssc.socketStream( "localhost", 12345, - new Converter(), + new Function>() { + @Override + public Iterable call(InputStream in) throws IOException { + List out = new ArrayList<>(); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(in))) { + for (String line; (line = reader.readLine()) != null;) { + out.add(line); + } + } + return out; + } + }, StorageLevel.MEMORY_ONLY()); } @@ -1870,7 +1864,7 @@ public void testFileStream() throws IOException { TextInputFormat.class, new Function() { @Override - public Boolean call(Path v1) throws Exception { + public Boolean call(Path v1) { return Boolean.TRUE; } }, @@ -1879,7 +1873,7 @@ public Boolean call(Path v1) throws Exception { JavaDStream test = inputStream.map( new Function, String>() { @Override - public String call(Tuple2 v1) throws Exception { + public String call(Tuple2 v1) { return v1._2().toString(); } }); @@ -1892,19 +1886,15 @@ public String call(Tuple2 v1) throws Exception { @Test public void testRawSocketStream() { - JavaReceiverInputDStream test = ssc.rawSocketStream("localhost", 12345); + ssc.rawSocketStream("localhost", 12345); } - private List> fileTestPrepare(File testDir) throws IOException { + private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); Files.write("0\n", existingFile, Charset.forName("UTF-8")); - assertTrue(existingFile.setLastModified(1000) && existingFile.lastModified() == 1000); - - List> expected = Arrays.asList( - Arrays.asList("0") - ); - - return expected; + Assert.assertTrue(existingFile.setLastModified(1000)); + Assert.assertEquals(1000, existingFile.lastModified()); + return Arrays.asList(Arrays.asList("0")); } @SuppressWarnings("unchecked") diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 1b0787fe69dec..ec2bffd6a5b97 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -36,7 +36,6 @@ import java.io.Serializable; import java.net.ConnectException; import java.net.Socket; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class JavaReceiverAPISuite implements Serializable { @@ -64,16 +63,16 @@ public void testReceiver() throws InterruptedException { ssc.receiverStream(new JavaSocketReceiver("localhost", server.port())); JavaDStream mapped = input.map(new Function() { @Override - public String call(String v1) throws Exception { + public String call(String v1) { return v1 + "."; } }); mapped.foreachRDD(new Function, Void>() { @Override - public Void call(JavaRDD rdd) throws Exception { - long count = rdd.count(); - dataCounter.addAndGet(count); - return null; + public Void call(JavaRDD rdd) { + long count = rdd.count(); + dataCounter.addAndGet(count); + return null; 
} }); @@ -83,7 +82,7 @@ public Void call(JavaRDD rdd) throws Exception { Thread.sleep(200); for (int i = 0; i < 6; i++) { - server.send("" + i + "\n"); // \n to make sure these are separate lines + server.send(i + "\n"); // \n to make sure these are separate lines Thread.sleep(100); } while (dataCounter.get() == 0 && System.currentTimeMillis() - startTime < timeout) { @@ -95,50 +94,49 @@ public Void call(JavaRDD rdd) throws Exception { server.stop(); } } -} -class JavaSocketReceiver extends Receiver { + private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + String host = null; + int port = -1; - public JavaSocketReceiver(String host_ , int port_) { - super(StorageLevel.MEMORY_AND_DISK()); - host = host_; - port = port_; - } + JavaSocketReceiver(String host_ , int port_) { + super(StorageLevel.MEMORY_AND_DISK()); + host = host_; + port = port_; + } - @Override - public void onStart() { - new Thread() { - @Override public void run() { - receive(); - } - }.start(); - } + @Override + public void onStart() { + new Thread() { + @Override public void run() { + receive(); + } + }.start(); + } - @Override - public void onStop() { - } + @Override + public void onStop() { + } - private void receive() { - Socket socket = null; - try { - socket = new Socket(host, port); - BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); - String userInput; - while ((userInput = in.readLine()) != null) { - store(userInput); + private void receive() { + try { + Socket socket = new Socket(host, port); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream())); + String userInput; + while ((userInput = in.readLine()) != null) { + store(userInput); + } + in.close(); + socket.close(); + } catch(ConnectException ce) { + ce.printStackTrace(); + restart("Could not connect", ce); + } catch(Throwable t) { + t.printStackTrace(); + restart("Error receiving data", t); } - in.close(); - socket.close(); - } catch(ConnectException ce) { - ce.printStackTrace(); - restart("Could not connect", ce); - } catch(Throwable t) { - t.printStackTrace(); - restart("Error receiving data", t); } } -} +} From f4a22808e03fa12bfe1bfc82cf713cfda7e063a9 Mon Sep 17 00:00:00 2001 From: JihongMa Date: Sat, 12 Sep 2015 10:17:15 -0700 Subject: [PATCH 268/802] [SPARK-6548] Adding stddev to DataFrame functions Adding STDDEV support for DataFrame using 1-pass online /parallel algorithm to compute variance. Please review the code change. Author: JihongMa Author: Jihong MA Author: Jihong MA Author: Jihong MA Closes #6297 from JihongMA/SPARK-SQL. 
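For reference, the "1-pass online/parallel algorithm" mentioned here is the Welford-style running update combined with the pairwise merge from the Wikipedia page cited in the new StddevAgg/ComputePartialStd code below. A minimal standalone sketch, assuming only the Scala standard library (StdState, update and merge are illustrative names, not part of this patch):

  final case class StdState(count: Double, mean: Double, m2: Double) {
    // One-pass update: avg += (x - avg)/n;  Mk += (x - preAvg) * (x - newAvg)
    def update(x: Double): StdState = {
      val n = count + 1
      val delta = x - mean
      val newMean = mean + delta / n
      StdState(n, newMean, m2 + delta * (x - newMean))
    }
    // Parallel merge of two partial states (the combine step used between partitions)
    def merge(that: StdState): StdState = {
      val n = count + that.count
      if (n == 0) this
      else {
        val delta = that.mean - mean
        StdState(n,
          (mean * count + that.mean * that.count) / n,
          m2 + that.m2 + delta * delta * count * that.count / n)
      }
    }
    // count == 0 -> undefined (the patch returns null); count == 1 -> 0.0
    def stddevSamp: Double =
      if (count < 1) Double.NaN else if (count < 2) 0.0 else math.sqrt(m2 / (count - 1))
    def stddevPop: Double =
      if (count < 1) Double.NaN else math.sqrt(m2 / count)
  }

  // e.g. Seq(1.0, 2.0, 3.0).foldLeft(StdState(0, 0, 0))(_ update _).stddevSamp  // == 1.0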
--- R/pkg/inst/tests/test_sparkSQL.R | 2 +- python/pyspark/sql/dataframe.py | 36 +-- .../catalyst/analysis/FunctionRegistry.scala | 3 + .../catalyst/analysis/HiveTypeCoercion.scala | 3 + .../spark/sql/catalyst/dsl/package.scala | 3 + .../expressions/aggregate/functions.scala | 143 ++++++++++ .../expressions/aggregate/utils.scala | 18 ++ .../sql/catalyst/expressions/aggregates.scala | 245 ++++++++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/GroupedData.scala | 39 +++ .../org/apache/spark/sql/functions.scala | 27 ++ .../apache/spark/sql/JavaDataFrameSuite.java | 1 + .../spark/sql/DataFrameAggregateSuite.scala | 33 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 ++- .../execution/AggregationQuerySuite.scala | 35 --- 16 files changed, 574 insertions(+), 64 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 1ccfde59176f5..98d4402d368e1 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1147,7 +1147,7 @@ test_that("describe() and summarize() on a DataFrame", { stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") - expect_equal(collect(stats)[3, "age"], "5.5") + expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c5bf55791240b..fb995fa3a76b5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -653,25 +653,25 @@ def describe(self, *cols): guarantee about the backward compatibility of the schema of the resulting DataFrame. 
>>> df.describe().show() - +-------+---+ - |summary|age| - +-------+---+ - | count| 2| - | mean|3.5| - | stddev|1.5| - | min| 2| - | max| 5| - +-------+---+ + +-------+------------------+ + |summary| age| + +-------+------------------+ + | count| 2| + | mean| 3.5| + | stddev|2.1213203435596424| + | min| 2| + | max| 5| + +-------+------------------+ >>> df.describe(['age', 'name']).show() - +-------+---+-----+ - |summary|age| name| - +-------+---+-----+ - | count| 2| 2| - | mean|3.5| null| - | stddev|1.5| null| - | min| 2|Alice| - | max| 5| Bob| - +-------+---+-----+ + +-------+------------------+-----+ + |summary| age| name| + +-------+------------------+-----+ + | count| 2| 2| + | mean| 3.5| null| + | stddev|2.1213203435596424| null| + | min| 2|Alice| + | max| 5| Bob| + +-------+------------------+-----+ """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index cd5a90d788151..11b4866bf264b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -168,6 +168,9 @@ object FunctionRegistry { expression[Last]("last"), expression[Max]("max"), expression[Min]("min"), + expression[Stddev]("stddev"), + expression[StddevPop]("stddev_pop"), + expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"), // string functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87c11abbad490..87a3845b2d9e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -297,6 +297,9 @@ object HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) + case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType)) + case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) + case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index a7e3a49327655..699c4cc63d09a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -159,6 +159,9 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def stddev(e: Expression): Expression = Stddev(e) + def stddev_pop(e: Expression): Expression = StddevPop(e) + def stddev_samp(e: Expression): Expression = StddevSamp(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? 
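The changed expectations above (R: 5.5 -> 7.7781745930520225; Python describe(): 1.5 -> 2.1213203435596424) follow from stddev now being the unbiased sample estimator rather than the previous population-style sqrt(avg(x^2) - avg(x)^2). A quick sanity check in plain Scala for the doctest ages {2, 5} (illustrative only, not part of the patch):

  val ages = Seq(2.0, 5.0)
  val mean = ages.sum / ages.size                          // 3.5
  val ss   = ages.map(a => (a - mean) * (a - mean)).sum    // 4.5
  math.sqrt(ss / ages.size)                                // 1.5, the old population-style value
  math.sqrt(ss / (ages.size - 1))                          // ~2.12132..., the new sample stddev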
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba1..02cd0ac0db118 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -249,6 +249,149 @@ case class Min(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = min } +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev" +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = false + override def prettyName: String = "stddev_pop" +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg(child) { + + override def isSample: Boolean = true + override def prettyName: String = "stddev_samp" +} + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + def isSample: Boolean + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the + // new version at planning time (after analysis phase). For now, NullType is added at here + // to make it resolved when we have cases like `select stddev(null)`. + // We can use our analyzer to cast NullType to the default data type of the NumericType once + // we remove the old aggregate functions. Then, we will not need NullType at here. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = DoubleType + + private val preCount = AttributeReference("preCount", resultType)() + private val currentCount = AttributeReference("currentCount", resultType)() + private val preAvg = AttributeReference("preAvg", resultType)() + private val currentAvg = AttributeReference("currentAvg", resultType)() + private val currentMk = AttributeReference("currentMk", resultType)() + + override val bufferAttributes = preCount :: currentCount :: preAvg :: + currentAvg :: currentMk :: Nil + + override val initialValues = Seq( + /* preCount = */ Cast(Literal(0), resultType), + /* currentCount = */ Cast(Literal(0), resultType), + /* preAvg = */ Cast(Literal(0), resultType), + /* currentAvg = */ Cast(Literal(0), resultType), + /* currentMk = */ Cast(Literal(0), resultType) + ) + + override val updateExpressions = { + + // update average + // avg = avg + (value - avg)/count + def avgAdd: Expression = { + currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) + } + + // update sum of square of difference from mean + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + def mkAdd: Expression = { + val delta1 = Cast(child, resultType) - preAvg + val delta2 = Cast(child, resultType) - currentAvg + currentMk + (delta1 * delta2) + } + + Seq( + /* preCount = */ If(IsNull(child), preCount, currentCount), + /* currentCount = */ If(IsNull(child), currentCount, + Add(currentCount, Cast(Literal(1), resultType))), + /* preAvg = */ If(IsNull(child), preAvg, currentAvg), + /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), + /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + ) + } + + override val mergeExpressions = { + + // count merge + def countMerge: Expression = { + currentCount.left + currentCount.right + } + + // average merge + def avgMerge: Expression = { + ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / + (preCount + currentCount.right) + } + + // update sum of square differences + def mkMerge: Expression = { + val avgDelta = currentAvg.right - preAvg + val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / + (preCount + currentCount.right) + + currentMk.left + currentMk.right + mkDelta + } + + Seq( + /* preCount = */ If(IsNull(currentCount.left), + Cast(Literal(0), resultType), currentCount.left), + /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, + If(IsNull(currentCount.right), currentCount.left, countMerge)), + /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), + /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, + If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), + /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, + If(IsNull(currentMk.right), currentMk.left, mkMerge)) + ) + } + + override val evaluateExpression = { + // when currentCount == 0, return null + // when currentCount == 1, return 0 + // when currentCount >1 + // stddev_samp = sqrt (currentMk/(currentCount -1)) + // stddev_pop = sqrt (currentMk/currentCount) + val varCol = { + if (isSample) { + currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) + } + else { + currentMk / currentCount + } + } + + If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + Cast(Sqrt(varCol), resultType))) + } +} + case class Sum(child: 
Expression) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a95490..ce3dddad87f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -85,6 +85,24 @@ object Utils { mode = aggregate.Complete, isDistinct = false) + case expressions.Stddev(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Stddev(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevPop(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevPop(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.StddevSamp(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StddevSamp(child), + mode = aggregate.Complete, + isDistinct = false) + case expressions.Sum(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Sum(child), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5e8298aaaa9cb..f1c47f39043c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -691,3 +691,248 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag result } } + +// Compute standard deviation based on online algorithm specified here: +// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { + override def nullable: Boolean = true + override def dataType: DataType = DoubleType + + def isSample: Boolean + + override def asPartial: SplitEvaluation = { + val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() + SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) + } + + override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") + +} + +// Compute the sample standard deviation of a column +case class Stddev(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV($child)" + override def isSample: Boolean = true +} + +// Compute the population standard deviation of a column +case class StddevPop(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_POP($child)" + override def isSample: Boolean = false +} + +// Compute the sample standard deviation of a column +case class StddevSamp(child: Expression) extends StddevAgg1(child) { + + override def toString: String = s"STDDEV_SAMP($child)" + override def isSample: Boolean = true +} + +case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = false + override def dataType: DataType = ArrayType(DoubleType) + override def toString: String = 
s"computePartialStddev($child)" + override def newInstance(): ComputePartialStdFunction = + new ComputePartialStdFunction(child, this) +} + +case class ComputePartialStdFunction ( + expr: Expression, + base: AggregateExpression1 +) extends AggregateFunction1 { + def this() = this(null, null) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private var partialCount: Long = 0L + + // the mean of data processed so far + private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update average based on this formula: + // avg = avg + (value - avg)/count + private def avgAddFunction (value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), partialAvg) + Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) + } + + // the sum of squares of difference from mean + private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) + + // update sum of square of difference from mean based on following formula: + // Mk = Mk + (value - preAvg) * (value - updatedAvg) + private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), prePartialAvg) + val delta2 = Subtract(Cast(value, computeType), partialAvg) + Add(partialMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + val prePartialAvg = partialAvg.copy() + partialCount += 1 + partialAvg.update(avgAddFunction(exprValue), input) + partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), + partialAvg.eval(null), + partialMk.eval(null))) + } +} + +case class MergePartialStd( + child: Expression, + isSample: Boolean +) extends UnaryExpression with AggregateExpression1 { + def this() = this(null, false) // required for serialization + + override def children: Seq[Expression] = child:: Nil + override def nullable: Boolean = false + override def dataType: DataType = DoubleType + override def toString: String = s"MergePartialStd($child)" + override def newInstance(): MergePartialStdFunction = { + new MergePartialStdFunction(child, this, isSample) + } +} + +case class MergePartialStdFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + def this() = this (null, null, false) // Required for serialization + + private val computeType = DoubleType + private val zero = Cast(Literal(0), computeType) + private val combineCount = MutableLiteral(zero.eval(null), computeType) + private val combineAvg = MutableLiteral(zero.eval(null), computeType) + private val combineMk = MutableLiteral(zero.eval(null), computeType) + + private def avgUpdateFunction(preCount: Expression, + partialCount: Expression, + partialAvg: Expression): Expression = { + Divide(Add(Multiply(combineAvg, preCount), + Multiply(partialAvg, partialCount)), + Add(preCount, partialCount)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] + + if (evaluatedExpr != null) { + val exprValue = evaluatedExpr.toArray(computeType) + val (partialCount, partialAvg, partialMk) = + (Literal.create(exprValue(0), computeType), + 
Literal.create(exprValue(1), computeType), + Literal.create(exprValue(2), computeType)) + + if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { + val preCount = combineCount.copy() + combineCount.update(Add(combineCount, partialCount), input) + + val preAvg = combineAvg.copy() + val avgDelta = Subtract(partialAvg, preAvg) + val mkDelta = Multiply(Multiply(avgDelta, avgDelta), + Divide(Multiply(preCount, partialCount), + combineCount)) + + // update average based on following formula + // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) + combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) + + // update sum of square differences from mean based on following formula + // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) + combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) + } + } + } + + override def eval(input: InternalRow): Any = { + val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] + + if (count == 0) null + else if (count < 2) zero.eval(null) + else { + // when total count > 2 + // stddev_samp = sqrt (combineMk/(combineCount -1)) + // stddev_pop = sqrt (combineMk/combineCount) + val varCol = { + if (isSample) { + Divide(combineMk, Cast(Literal(count - 1), computeType)) + } + else { + Divide(combineMk, Cast(Literal(count), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} + +case class StddevFunction( + expr: Expression, + base: AggregateExpression1, + isSample: Boolean +) extends AggregateFunction1 { + + def this() = this(null, null, false) // Required for serialization + + private val computeType = DoubleType + private var curCount: Long = 0L + private val zero = Cast(Literal(0), computeType) + private val curAvg = MutableLiteral(zero.eval(null), computeType) + private val curMk = MutableLiteral(zero.eval(null), computeType) + + private def curAvgAddFunction(value: Literal): Expression = { + val delta = Subtract(Cast(value, computeType), curAvg) + Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) + } + private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { + val delta1 = Subtract(Cast(value, computeType), preAvg) + val delta2 = Subtract(Cast(value, computeType), curAvg) + Add(curMk, Multiply(delta1, delta2)) + } + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + val preAvg: MutableLiteral = curAvg.copy() + val exprValue = Literal.create(evaluatedExpr, expr.dataType) + curCount += 1L + curAvg.update(curAvgAddFunction(exprValue), input) + curMk.update(curMkAddFunction(exprValue, preAvg), input) + } + } + + override def eval(input: InternalRow): Any = { + if (curCount == 0) null + else if (curCount < 2) zero.eval(null) + else { + // when total count > 2, + // stddev_samp = sqrt(curMk/(curCount - 1)) + // stddev_pop = sqrt(curMk/curCount) + val varCol = { + if (isSample) { + Divide(curMk, Cast(Literal(curCount - 1), computeType)) + } + else { + Divide(curMk, Cast(Literal(curCount), computeType)) + } + } + Sqrt(varCol).eval(null) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 791c10c3d7ce7..1a687b2374f14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1288,15 +1288,11 @@ class DataFrame private[sql]( 
@scala.annotation.varargs def describe(cols: String*): DataFrame = { - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) - // The list of summary statistics to compute, in the form of expressions. val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> stddevExpr, + "stddev" -> Stddev, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index ee31d83cce42c..102b802ad0a0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -124,6 +124,9 @@ class GroupedData protected[sql]( case "avg" | "average" | "mean" => Average case "max" => Max case "min" => Min + case "stddev" => Stddev + case "stddev_pop" => StddevPop + case "stddev_samp" => StddevSamp case "sum" => Sum case "count" | "size" => // Turn count(*) into count(1) @@ -283,6 +286,42 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Min) } + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(Stddev) + } + + /** + * Compute the population standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_pop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevPop) + } + + /** + * Compute the sample standard deviation for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the stddev for them. + * + * @since 1.6.0 + */ + @scala.annotation.varargs + def stddev_samp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(StddevSamp) + } + /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 435e6319a64c4..60d9c509104d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -294,6 +294,33 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the unbiased sample standard deviation + * of the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev(e: Column): Column = Stddev(e.expr) + + /** + * Aggregate function: returns the population standard deviation of + * the expression in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_pop(e: Column): Column = StddevPop(e.expr) + + /** + * Aggregate function: returns the unbiased sample standard deviation of + * the expression in a group. 
+ * + * @group agg_funcs + * @since 1.6.0 + */ + def stddev_samp(e: Column): Column = StddevSamp(e.expr) + /** * Aggregate function: returns the sum of all values in the expression. * diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index d981ce947f435..5f9abd4999ce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -90,6 +90,7 @@ public void testVarargMethods() { df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); + df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index c0950b09b14ad..f5ef9ffd7f4f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -175,6 +175,39 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(0, null)) } + test("stddev") { + val testData2ADev = math.sqrt(4/5.0) + + checkAnswer( + testData2.agg(stddev('a)), + Row(testData2ADev)) + + checkAnswer( + testData2.agg(stddev_pop('a)), + Row(math.sqrt(4/6.0))) + + checkAnswer( + testData2.agg(stddev_samp('a)), + Row(testData2ADev)) + } + + test("zero stddev") { + val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") + assert(emptyTableData.count() == 0) + + checkAnswer( + emptyTableData.agg(stddev('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_pop('a)), + Row(null)) + + checkAnswer( + emptyTableData.agg(stddev_samp('a)), + Row(null)) + } + test("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index dbed4fc247140..c167999af580e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -436,7 +436,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 664b7a1512faf..962b100b532c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -328,6 +328,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { testCodeGen( "SELECT min(key) FROM testData3x", Row(1) :: Nil) + // STDDEV + testCodeGen( + "SELECT a, stddev(b), stddev_pop(b) FROM testData2 GROUP BY a", + (1 to 3).map(i => Row(i, math.sqrt(0.5), math.sqrt(0.25)))) + testCodeGen( + "SELECT stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2", + Row(math.sqrt(1.5 / 5), math.sqrt(1.5 / 6), math.sqrt(1.5 / 5)) :: Nil) // Some combinations. 
testCodeGen( """ @@ -348,8 +355,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(100, 1, 50.5, 300, 100) :: Nil) // Aggregate with Code generation handling all null values testCodeGen( - "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(null, null, 0) :: Nil) + "SELECT sum('a'), avg('a'), stddev('a'), count(null) FROM testData", + Row(null, null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) @@ -515,8 +522,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregates with nulls") { checkAnswer( - sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), - Row(1, 3, 2, 6, 3) + sql("SELECT MIN(a), MAX(a), AVG(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"), + Row(1, 3, 2, 1, 6, 3) ) } @@ -722,6 +729,33 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("stddev") { + checkAnswer( + sql("SELECT STDDEV(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev_pop") { + checkAnswer( + sql("SELECT STDDEV_POP(a) FROM testData2"), + Row(math.sqrt(4/6.0)) + ) + } + + test("stddev_samp") { + checkAnswer( + sql("SELECT STDDEV_SAMP(a) FROM testData2"), + Row(math.sqrt(4/5.0)) + ) + } + + test("stddev agg") { + checkAnswer( + sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"), + (1 to 3).map(i => Row(i, math.sqrt(1/2.0), math.sqrt(1/4.0), math.sqrt(1/2.0)))) + } + test("inner join where, one match per row") { checkAnswer( sql("SELECT * FROM upperCaseData JOIN lowerCaseData WHERE n = N"), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index b126ec455fc69..a73b1bd52c09f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -507,41 +507,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te }.getMessage assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } - - // TODO: once we support Hive UDAF in the new interface, - // we can remove the following two tests. - withSQLConf("spark.sql.useAggregate2" -> "true") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | mydoublesum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - - // This will fall back to the old aggregate - val newAggregateOperators = sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | stddev_samp(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.SortBasedAggregate => agg - case agg: aggregate.TungstenAggregate => agg - } - val message = - "We should fallback to the old aggregation code path if " + - "there is any aggregate function that cannot be converted to the new interface." 
- assert(newAggregateOperators.isEmpty, message) - } } } From b3a7480ab0821ab38f710de96e3ac4a13f62dbca Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 12 Sep 2015 16:23:55 -0700 Subject: [PATCH 269/802] [SPARK-10330] Add Scalastyle rule to require use of SparkHadoopUtil JobContext methods This is a followup to #8499 which adds a Scalastyle rule to mandate the use of SparkHadoopUtil's JobContext accessor methods and fixes the existing violations. Author: Josh Rosen Closes #8521 from JoshRosen/SPARK-10330-part2. --- .../src/main/scala/org/apache/spark/SparkContext.scala | 6 +++--- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 4 ++++ .../scala/org/apache/spark/rdd/PairRDDFunctions.scala | 8 +++++--- .../scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala | 2 +- core/src/test/scala/org/apache/spark/FileSuite.scala | 6 ++++-- .../org/apache/spark/examples/CassandraCQLTest.scala | 3 +++ .../org/apache/spark/examples/CassandraTest.scala | 2 ++ scalastyle-config.xml | 8 ++++++++ .../sql/execution/datasources/WriterContainer.scala | 8 ++++++-- .../sql/execution/datasources/json/JSONRelation.scala | 2 +- .../datasources/parquet/CatalystReadSupport.scala | 6 +++++- .../parquet/DirectParquetOutputCommitter.scala | 6 +++++- .../datasources/parquet/ParquetRelation.scala | 10 +++++++--- .../datasources/parquet/ParquetTypesConverter.scala | 6 +++++- .../org/apache/spark/sql/hive/orc/OrcRelation.scala | 4 ++-- 15 files changed, 61 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cbfe8bf31c3d6..e27b3c4962221 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -858,7 +858,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], @@ -910,7 +910,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = job.getConfiguration + val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1092,7 +1092,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = job.getConfiguration + val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index f7723ef5bde4c..a0b7365df900a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -192,7 +192,9 @@ class SparkHadoopUtil extends Logging { * while it's interface in Hadoop 2.+. */ def getConfigurationFromJobContext(context: JobContext): Configuration = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getConfiguration") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[Configuration] } @@ -204,7 +206,9 @@ class SparkHadoopUtil extends Logging { */ def getTaskAttemptIDFromTaskAttemptContext( context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { + // scalastyle:off jobconfig val method = context.getClass.getMethod("getTaskAttemptID") + // scalastyle:on jobconfig method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c59f0d4aa75a0..199d79b811d65 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -996,8 +996,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - job.getConfiguration.set("mapred.output.dir", path) - saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(jobConfiguration) } /** @@ -1064,7 +1065,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableConfiguration(job.getConfiguration) + val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 9babe56267e08..0228c54e0511c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -86,7 +86,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - job.getConfiguration + SparkHadoopUtil.get.getConfigurationFromJobContext(job) } private val jobTrackerId: String = { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 418763f4e5ffa..fdb00aafc4a48 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{File, FileWriter} +import org.apache.spark.deploy.SparkHadoopUtil import 
org.apache.spark.input.PortableDataStream import org.apache.spark.storage.StorageLevel @@ -506,8 +507,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - job.getConfiguration.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") - randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index fa07c1e5017cd..d1b9b8d398dd8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println + // scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -81,6 +82,7 @@ object CassandraCQLTest { val job = new Job() job.setInputFormatClass(classOf[CqlPagingInputFormat]) + val configuration = job.getConfiguration ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) @@ -135,3 +137,4 @@ object CassandraCQLTest { } } // scalastyle:on println +// scalastyle:on jobcontext diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 2e56d24c60c33..1e679bfb55343 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -16,6 +16,7 @@ */ // scalastyle:off println +// scalastyle:off jobcontext package org.apache.spark.examples import java.nio.ByteBuffer @@ -130,6 +131,7 @@ object CassandraTest { } } // scalastyle:on println +// scalastyle:on jobcontext /* create keyspace casDemo; diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 68fdb4141cf27..64a0c71bbef2a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -168,6 +168,14 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + ^getConfiguration$|^getTaskAttemptID$ + Instead of calling .getConfiguration() or .getTaskAttemptID() directly, + use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. 
+ + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 9a573db0c023a..f8ef674ed29c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -47,7 +47,8 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + protected val serializableConf = + new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -89,7 +90,8 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + SparkHadoopUtil.get.getConfigurationFromJobContext(job). + set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -182,7 +184,9 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) this.taskId = new TaskID(this.jobId, true, splitId) + // scalastyle:off jobcontext this.taskAttemptId = new TaskAttemptID(taskId, attemptId) + // scalastyle:on jobcontext } private def setupConf(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 7a49157d9e72c..8ee0127c3bde8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -81,7 +81,7 @@ private[sql] class JSONRelation( private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) val paths = inputPaths.map(_.getPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index 5a8166fac5418..8c819f1a48cd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -72,7 +72,11 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with // Called before `prepareForRead()` when initializing Parquet record reader. 
override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration + val conf = { + // scalastyle:off jobcontext + context.getConfiguration + // scalastyle:on jobcontext + } // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst // schema of this file from its metadata. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 2c6b914328b60..de1fd0166ac5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -53,7 +53,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) + val configuration = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(jobContext) + // scalastyle:on jobcontext + } val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index c6bbc392cad4c..953fcab126970 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -211,7 +211,11 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) + val conf = { + // scalastyle:off jobcontext + ContextUtil.getConfiguration(job) + // scalastyle:on jobcontext + } // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassname = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -528,7 +532,7 @@ private[sql] object ParquetRelation extends Logging { assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean, followParquetFormatSpec: Boolean)(job: Job): Unit = { - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. 
@@ -572,7 +576,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, job.getConfiguration) + overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) } private[parquet] def readSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala index 142301fe87cb6..b647bb6116afa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -123,7 +123,11 @@ private[parquet] object ParquetTypesConverter extends Logging { throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") } val job = new Job() - val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val conf = { + // scalastyle:off jobcontext + configuration.getOrElse(ContextUtil.getConfiguration(job)) + // scalastyle:on jobcontext + } val fs: FileSystem = origPath.getFileSystem(conf) if (fs == null) { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7e89109259955..d1f30e188eafb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -208,7 +208,7 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { - job.getConfiguration match { + SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -289,7 +289,7 @@ private[orc] case class OrcTableScan( def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration + val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { From 1dc614b874badde0eee60def46fb47f608bc4759 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 13 Sep 2015 08:36:46 +0100 Subject: [PATCH 270/802] [SPARK-10222] [GRAPHX] [DOCS] More thoroughly deprecate Bagel in favor of GraphX Finish deprecating Bagel; remove reference to nonexistent example Author: Sean Owen Closes #8731 from srowen/SPARK-10222. 
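As a point of reference for users migrating off Bagel, the superstep-style compute loop maps directly onto GraphX's Pregel operator. The sketch below is illustrative only (it is not part of this patch and assumes a synthetic log-normal graph); it follows the single-source shortest paths example from the GraphX programming guide:

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx.{Graph, VertexId}
    import org.apache.spark.graphx.util.GraphGenerators

    // Illustrative sketch: single-source shortest paths with graphx.Pregel,
    // the GraphX counterpart of a Bagel compute/superstep loop.
    def shortestPaths(sc: SparkContext, sourceId: VertexId): Unit = {
      // Synthetic graph whose edge attributes serve as distances (an assumption for this example).
      val graph: Graph[Long, Double] =
        GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
      // Every vertex except the source starts at infinity.
      val initialGraph = graph.mapVertices((id, _) =>
        if (id == sourceId) 0.0 else Double.PositiveInfinity)
      val sssp = initialGraph.pregel(Double.PositiveInfinity)(
        (id, dist, newDist) => math.min(dist, newDist),   // vertex program
        triplet => {                                      // send messages along cheaper paths
          if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
            Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
          } else {
            Iterator.empty
          }
        },
        (a, b) => math.min(a, b)                          // merge incoming messages
      )
      println(sssp.vertices.collect.mkString("\n"))
    }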
--- .../src/main/scala/org/apache/spark/bagel/Bagel.scala | 6 ++++++ docs/bagel-programming-guide.md | 10 +--------- docs/index.md | 1 - pom.xml | 2 +- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 4e6b7686f771d..8399033ac61ec 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -22,6 +22,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") object Bagel extends Logging { val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK @@ -270,18 +271,21 @@ object Bagel extends Logging { } } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Combiner[M, C] { def createCombiner(msg: M): C def mergeMsg(combiner: C, msg: M): C def mergeCombiners(a: C, b: C): C } +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Aggregator[V, A] { def createAggregator(vert: V): A def mergeAggregators(a: A, b: A): A } /** Default combiner that simply appends messages together (i.e. performs no aggregation) */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) @@ -297,6 +301,7 @@ class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializab * Subclasses may store state along with each vertex and must * inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Vertex { def active: Boolean } @@ -307,6 +312,7 @@ trait Vertex { * Subclasses may contain a payload to deliver to the target vertex * and must inherit from java.io.Serializable or scala.Serializable. */ +@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") trait Message[K] { def targetId: K } diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index c2fe6b0e286ce..347ca4a7af989 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -4,7 +4,7 @@ displayTitle: Bagel Programming Guide title: Bagel --- -**Bagel will soon be superseded by [GraphX](graphx-programming-guide.html); we recommend that new users try GraphX instead.** +**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. @@ -157,11 +157,3 @@ trait Message[K] { def targetId: K } {% endhighlight %} - -# Where to Go from Here - -Two example jobs, PageRank and shortest path, are included in `examples/src/main/scala/org/apache/spark/examples/bagel`. You can run them by passing the class name to the `bin/run-example` script included in Spark; e.g.: - - ./bin/run-example org.apache.spark.examples.bagel.WikipediaPageRank - -Each example program prints usage help when run without any arguments. 
diff --git a/docs/index.md b/docs/index.md index d85cf12defefd..c0dc2b8d7412a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -90,7 +90,6 @@ options for deployment: * [Spark SQL and DataFrames](sql-programming-guide.html): support for structured data and relational queries * [MLlib](mllib-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing - * [Bagel (Pregel on Spark)](bagel-programming-guide.html): older, simple graph processing model **API Docs:** diff --git a/pom.xml b/pom.xml index 88ebceca769e9..421357e141572 100644 --- a/pom.xml +++ b/pom.xml @@ -87,7 +87,7 @@ core - bagel + bagel graphx mllib tools From d81565465cc6d4f38b4ed78036cded630c700388 Mon Sep 17 00:00:00 2001 From: Bertrand Dechoux Date: Mon, 14 Sep 2015 09:18:46 +0100 Subject: [PATCH 271/802] [SPARK-9720] [ML] Identifiable types need UID in toString methods A few Identifiable types did override their toString method but without using the parent implementation. As a consequence, the uid was not present anymore in the toString result. It is the default behaviour. This patch is a quick fix. The question of enforcement is still up. No tests have been written to verify the toString method behaviour. That would be long to do because all types should be tested and not only those which have a regression now. It is possible to enforce the condition using the compiler by making the toString method final but that would introduce unwanted potential API breaking changes (see jira). Author: Bertrand Dechoux Closes #8062 from BertrandDechoux/SPARK-9720. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../org/apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/NaiveBayes.scala | 2 +- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4 ++-- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 0a75d5d22280f..b8eb49f9bdb48 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -146,7 +146,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def toString: String = { - s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3073a2a61ce83..ad8683648b975 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -200,7 +200,7 @@ final class GBTClassificationModel( } override def toString: String = { - s"GBTClassificationModel with $numTrees trees" + s"GBTClassificationModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert to a model in the 
old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 69cb88a7e6718..082ea1ffad58f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -198,7 +198,7 @@ class NaiveBayesModel private[ml] ( } override def toString: String = { - s"NaiveBayesModel with ${pi.size} classes" + s"NaiveBayesModel (uid=$uid) with ${pi.size} classes" } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 11a6d72468333..a6ebee1bb10af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -193,7 +193,7 @@ final class RandomForestClassificationModel private[ml] ( } override def toString: String = { - s"RandomForestClassificationModel with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index a7fa50444209b..dcd6fe3c406a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -129,7 +129,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" } /** @@ -171,7 +171,7 @@ class RFormulaModel private[feature]( override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${resolvedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula}) (uid=$uid)" private def transformLabel(dataset: DataFrame): DataFrame = { val labelName = resolvedFormula.label diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index a2bcd67401d08..d9a244bea28d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -118,7 +118,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def toString: String = { - s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" } /** Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b66e61f37dd5e..d841ecb9e58d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -189,7 +189,7 @@ final class GBTRegressionModel( } override def toString: String = { - s"GBTRegressionModel with $numTrees trees" + s"GBTRegressionModel (uid=$uid) with $numTrees trees" } /** (private[ml]) Convert 
to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 2f36da371f577..ddb7214416a69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -155,7 +155,7 @@ final class RandomForestRegressionModel private[ml] ( } override def toString: String = { - s"RandomForestRegressionModel with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" } /** From 32407bfd2bdbf84d65cacfa7554dae6a2332bc37 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Sep 2015 11:51:39 -0700 Subject: [PATCH 272/802] [SPARK-9899] [SQL] log warning for direct output committer with speculation enabled This is a follow-up of https://github.com/apache/spark/pull/8317. When speculation is enabled, there may be multiply tasks writing to the same path. Generally it's OK as we will write to a temporary directory first and only one task can commit the temporary directory to target path. However, when we use direct output committer, tasks will write data to target path directly without temporary directory. This causes problems like corrupted data. Please see [PR comment](https://github.com/apache/spark/pull/8191#issuecomment-131598385) for more details. Unfortunately, we don't have a simple flag to tell if a output committer will write to temporary directory or not, so for safety, we have to disable any customized output committer when `speculation` is true. Author: Wenchen Fan Closes #8687 from cloud-fan/direct-committer. --- .../apache/spark/rdd/PairRDDFunctions.scala | 44 ++++++++++++++++--- .../hive/execution/InsertIntoHiveTable.scala | 17 ++++++- .../spark/sql/hive/hiveWriterContainers.scala | 1 - 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 199d79b811d65..a981b63942e6d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1018,6 +1018,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsHadoopFile( path: String, @@ -1030,10 +1035,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val hadoopConf = conf hadoopConf.setOutputKeyClass(keyClass) hadoopConf.setOutputValueClass(valueClass) - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? 
- // conf.setOutputFormat(outputFormatClass) - hadoopConf.set("mapred.output.format.class", outputFormatClass.getName) + conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) hadoopConf.set("mapred.output.compress", "true") @@ -1047,6 +1049,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) hadoopConf.setOutputCommitter(classOf[FileOutputCommitter]) } + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = hadoopConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + FileOutputFormat.setOutputPath(hadoopConf, SparkHadoopWriter.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) @@ -1057,6 +1072,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Configuration object for that storage system. The Conf should set an OutputFormat and any * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. + * + * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * not use output committer that writes data directly. + * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad + * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). @@ -1115,6 +1135,20 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = self.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobCommitter.getClass.getSimpleName + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + jobCommitter.setupJob(jobTaskContext) self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) @@ -1129,7 +1163,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf - val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1157,7 +1190,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 58f7fa640e8a9..0c700bdb370ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -62,7 +62,7 @@ case class InsertIntoHiveTable( def output: Seq[Attribute] = Seq.empty - def saveAsHiveFile( + private def saveAsHiveFile( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, @@ -178,6 +178,19 @@ case class InsertIntoHiveTable( val jobConf = new JobConf(sc.hiveconf) val jobConfSer = new SerializableJobConf(jobConf) + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false) + val outputCommitterClass = jobConf.get("mapred.output.committer.class", "") + if (speculationEnabled && outputCommitterClass.contains("Direct")) { + val warningMessage = + s"$outputCommitterClass may be an output committer that writes data directly to " + + "the final location. Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." 
+ logWarning(warningMessage) + } + val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 29a6f08f40728..4ca8042d22367 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils From cf2821ef5fd9965eb6256e8e8b3f1e00c0788098 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 14 Sep 2015 12:06:23 -0700 Subject: [PATCH 273/802] [SPARK-10584] [DOC] [SQL] Documentation about spark.sql.hive.metastore.version is wrong. The default value of hive metastore version is 1.2.1 but the documentation says the value of `spark.sql.hive.metastore.version` is 0.13.1. Also, we cannot get the default value by `sqlContext.getConf("spark.sql.hive.metastore.version")`. Author: Kousuke Saruta Closes #8739 from sarutak/SPARK-10584. --- docs/sql-programming-guide.md | 2 +- .../scala/org/apache/spark/sql/hive/HiveContext.scala | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 6a1b0fbfa1eb3..a0b911d207243 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1687,7 +1687,7 @@ The following options can be used to configure the version of Hive that is used Property NameDefaultMeaning spark.sql.hive.metastore.version - 0.13.1 + 1.2.1 Version of the Hive metastore. Available options are 0.12.0 through 1.2.1. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e791cea96b41..d37ba5ddc2d80 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -111,8 +111,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * this does not necessarily need to be the same version of Hive that is used internally by * Spark SQL for execution. */ - protected[hive] def hiveMetastoreVersion: String = - getConf(HIVE_METASTORE_VERSION, hiveExecutionVersion) + protected[hive] def hiveMetastoreVersion: String = getConf(HIVE_METASTORE_VERSION) /** * The location of the jars that should be used to instantiate the HiveMetastoreClient. This @@ -202,7 +201,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { "Builtin jars can only be used when hive execution version == hive metastore version. " + s"Execution: ${hiveExecutionVersion} != Metastore: ${hiveMetastoreVersion}. 
" + "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + - s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") + s"or change ${HIVE_METASTORE_VERSION.key} to $hiveExecutionVersion.") } // We recursively find all jars in the class loader chain, @@ -606,7 +605,11 @@ private[hive] object HiveContext { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" - val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" + val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", + defaultValue = Some(hiveExecutionVersion), + doc = "Version of the Hive metastore. Available options are " + + s"0.12.0 through $hiveExecutionVersion.") + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", defaultValue = Some("builtin"), doc = s""" From ce6f3f163bc667cb5da9ab4331c8bad10cc0d701 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 12:08:52 -0700 Subject: [PATCH 274/802] [SPARK-10194] [MLLIB] [PYSPARK] SGD algorithms need convergenceTol parameter in Python [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382) added a ```convergenceTol``` parameter for GradientDescent-based methods in Scala. We need that parameter in Python; otherwise, Python users will not be able to adjust that behavior (or even reproduce behavior from previous releases since the default changed). Author: Yanbo Liang Closes #8457 from yanboliang/spark-10194. --- .../mllib/api/python/PythonMLLibAPI.scala | 20 +++++++++--- python/pyspark/mllib/classification.py | 17 +++++++--- python/pyspark/mllib/regression.py | 32 ++++++++++++------- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f585aacd452e0..69ce7f50709a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -132,7 +132,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) .setValidateData(validateData) @@ -141,6 +142,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, @@ -159,7 +161,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.setIntercept(intercept) .setValidateData(validateData) @@ -168,6 +171,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( lassoAlg, data, @@ -185,7 +189,8 @@ private[python] class PythonMLLibAPI extends Serializable { miniBatchFraction: Double, initialWeights: Vector, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, 
+ convergenceTol: Double): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.setIntercept(intercept) .setValidateData(validateData) @@ -194,6 +199,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) trainRegressionModel( ridgeAlg, data, @@ -212,7 +218,8 @@ private[python] class PythonMLLibAPI extends Serializable { initialWeights: Vector, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) .setValidateData(validateData) @@ -221,6 +228,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, @@ -240,7 +248,8 @@ private[python] class PythonMLLibAPI extends Serializable { regParam: Double, regType: String, intercept: Boolean, - validateData: Boolean): JList[Object] = { + validateData: Boolean, + convergenceTol: Double): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) .setValidateData(validateData) @@ -249,6 +258,7 @@ private[python] class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) + .setConvergenceTol(convergenceTol) LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 8f27c446a66e8..cb4ee83678081 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -241,7 +241,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType="l2", intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a logistic regression model on the given data. @@ -274,11 +274,13 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -439,7 +441,7 @@ class SVMWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType="l2", - intercept=False, validateData=True): + intercept=False, validateData=True, convergenceTol=0.001): """ Train a support vector machine on the given data. @@ -472,11 +474,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, regType, - bool(intercept), bool(validateData)) + bool(intercept), bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) @@ -600,12 +604,15 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. :param regParam: L2 Regularization parameter. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01, + convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.regParam = regParam self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLogisticRegressionWithSGD, self).__init__( model=self._model) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 41946e3674fbe..256b7537fef6b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -28,7 +28,8 @@ 'LinearRegressionModel', 'LinearRegressionWithSGD', 'RidgeRegressionModel', 'RidgeRegressionWithSGD', 'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel', - 'IsotonicRegression'] + 'IsotonicRegression', 'StreamingLinearAlgorithm', + 'StreamingLinearRegressionWithSGD'] class LabeledPoint(object): @@ -202,7 +203,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.0, regType=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a linear regression model using Stochastic Gradient Descent (SGD). @@ -244,11 +245,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), float(step), float(miniBatchFraction), i, float(regParam), - regType, bool(intercept), bool(validateData)) + regType, bool(intercept), bool(validateData), + float(convergenceTol)) return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) @@ -330,7 +334,7 @@ class LassoWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L1-regularization using Stochastic Gradient Descent. @@ -362,11 +366,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. 
+ (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -449,7 +455,7 @@ class RidgeRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, intercept=False, - validateData=True): + validateData=True, convergenceTol=0.001): """ Train a regression model with L2-regularization using Stochastic Gradient Descent. @@ -481,11 +487,13 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, :param validateData: Boolean parameter which indicates if the algorithm should validate data before training. (default: True) + :param convergenceTol: A condition which decides iteration termination. + (default: 0.001) """ def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), float(regParam), float(miniBatchFraction), i, bool(intercept), - bool(validateData)) + bool(validateData), float(convergenceTol)) return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) @@ -636,15 +644,17 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): After training on a batch of data, the weights obtained at the end of training are used as initial weights for the next batch. - :param: stepSize Step size for each iteration of gradient descent. - :param: numIterations Total number of iterations run. - :param: miniBatchFraction Fraction of data on which SGD is run for each + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Total number of iterations run. + :param miniBatchFraction: Fraction of data on which SGD is run for each iteration. + :param convergenceTol: A condition which decides iteration termination. """ - def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, convergenceTol=0.001): self.stepSize = stepSize self.numIterations = numIterations self.miniBatchFraction = miniBatchFraction + self.convergenceTol = convergenceTol self._model = None super(StreamingLinearRegressionWithSGD, self).__init__( model=self._model) From 8a634e9bcc671167613fb575c6c0c054fb4b3479 Mon Sep 17 00:00:00 2001 From: Nick Pritchard Date: Mon, 14 Sep 2015 13:27:45 -0700 Subject: [PATCH 275/802] [SPARK-10573] [ML] IndexToString output schema should be StringType Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata. Author: Nick Pritchard Closes #8751 from pnpritchard/SPARK-10573. 
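For reference, a minimal sketch of the behavior this patch fixes (it mirrors the regression test added below, with hypothetical column names): the declared output column of IndexToString should be StringType, not DoubleType:

    import org.apache.spark.ml.feature.IndexToString
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

    // Sketch only: transformSchema should now report StringType for the output column.
    val idxToStr = new IndexToString().setInputCol("categoryIndex").setOutputCol("category")
    val inSchema = StructType(Seq(StructField("categoryIndex", DoubleType)))
    val outSchema = idxToStr.transformSchema(inSchema)
    assert(outSchema("category").dataType == StringType)  // reported as DoubleType before this fix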
--- .../scala/org/apache/spark/ml/feature/StringIndexer.scala | 5 ++--- .../org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3a4ab9a857648..2b1592930e77b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap /** @@ -229,8 +229,7 @@ class IndexToString private[ml] ( val outputColName = $(outputCol) require(inputFields.forall(_.name != outputColName), s"Output column $outputColName already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val outputFields = inputFields :+ StructField($(outputCol), StringType) StructType(outputFields) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 05e05bdc64bb1..ddcdb5f4212be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -165,4 +166,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(a === b) } } + + test("IndexToString.transformSchema (SPARK-10573)") { + val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output") + val inSchema = StructType(Seq(StructField("input", DoubleType))) + val outSchema = idxToStr.transformSchema(inSchema) + assert(outSchema("output").dataType === StringType) + } } From 7e32387ae6303fd1cd32389d47df87170b841c67 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 14:10:54 -0700 Subject: [PATCH 276/802] [SPARK-10522] [SQL] Nanoseconds of Timestamp in Parquet should be positive Or Hive can't read it back correctly. Thanks vanzin for report this. Author: Davies Liu Closes #8674 from davies/positive_nano. 
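To make the motivation concrete, here is a small worked sketch of the new arithmetic (not an excerpt from the patch; the Julian-day-of-epoch constant is assumed to be 2440588, matching DateTimeUtils). Flooring against the Julian epoch keeps the intra-day part non-negative even for timestamps before 1970:

    // Assumed constants; mirrors the rewritten toJulianDay for illustration only.
    val JULIAN_DAY_OF_EPOCH = 2440588L                    // assumed Julian day of 1970-01-01
    val MICROS_PER_DAY = 24L * 60 * 60 * 1000 * 1000

    def toJulianDaySketch(us: Long): (Long, Long) = {
      val julianUs = us + JULIAN_DAY_OF_EPOCH * MICROS_PER_DAY
      (julianUs / MICROS_PER_DAY, (julianUs % MICROS_PER_DAY) * 1000L)
    }

    // One microsecond before the epoch:
    //   old logic: (2440588, -1000)            -- negative nanoseconds, which Hive cannot read back
    //   new logic: (2440587, 86399999999000)   -- nanoseconds stay within [0, one full day)
    println(toJulianDaySketch(-1L))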
--- .../spark/sql/catalyst/util/DateTimeUtils.scala | 12 +++++++----- .../sql/catalyst/util/DateTimeUtilsSuite.scala | 17 ++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index d652fce3fd9b6..687ca000d12bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -42,6 +42,7 @@ object DateTimeUtils { final val SECONDS_PER_DAY = 60 * 60 * 24L final val MICROS_PER_SECOND = 1000L * 1000L final val NANOS_PER_SECOND = MICROS_PER_SECOND * 1000L + final val MICROS_PER_DAY = MICROS_PER_SECOND * SECONDS_PER_DAY final val MILLIS_PER_DAY = SECONDS_PER_DAY * 1000L @@ -190,13 +191,14 @@ object DateTimeUtils { /** * Returns Julian day and nanoseconds in a day from the number of microseconds + * + * Note: support timestamp since 4717 BC (without negative nanoseconds, compatible with Hive). */ def toJulianDay(us: SQLTimestamp): (Int, Long) = { - val seconds = us / MICROS_PER_SECOND - val day = seconds / SECONDS_PER_DAY + JULIAN_DAY_OF_EPOCH - val secondsInDay = seconds % SECONDS_PER_DAY - val nanos = (us % MICROS_PER_SECOND) * 1000L - (day.toInt, secondsInDay * NANOS_PER_SECOND + nanos) + val julian_us = us + JULIAN_DAY_OF_EPOCH * MICROS_PER_DAY + val day = julian_us / MICROS_PER_DAY + val micros = julian_us % MICROS_PER_DAY + (day.toInt, micros * 1000L) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 1596bb79fa94b..6b9a11f0ff743 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -52,15 +52,14 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(ns === 0) assert(fromJulianDay(d, ns) == 0L) - val t = Timestamp.valueOf("2015-06-11 10:10:10.100") - val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) - val t1 = toJavaTimestamp(fromJulianDay(d1, ns1)) - assert(t.equals(t1)) - - val t2 = Timestamp.valueOf("2015-06-11 20:10:10.100") - val (d2, ns2) = toJulianDay(fromJavaTimestamp(t2)) - val t22 = toJavaTimestamp(fromJulianDay(d2, ns2)) - assert(t2.equals(t22)) + Seq(Timestamp.valueOf("2015-06-11 10:10:10.100"), + Timestamp.valueOf("2015-06-11 20:10:10.100"), + Timestamp.valueOf("1900-06-11 20:10:10.100")).foreach { t => + val (d, ns) = toJulianDay(fromJavaTimestamp(t)) + assert(ns > 0) + val t1 = toJavaTimestamp(fromJulianDay(d, ns)) + assert(t.equals(t1)) + } } test("SPARK-6785: java date conversion before and after epoch") { From 64f04154e3078ec7340da97e3c2b07cf24e89098 Mon Sep 17 00:00:00 2001 From: Edoardo Vacchi Date: Mon, 14 Sep 2015 14:56:04 -0700 Subject: [PATCH 277/802] [SPARK-6981] [SQL] Factor out SparkPlanner and QueryExecution from SQLContext Alternative to PR #6122; in this case the refactored out classes are replaced by inner classes with the same name for backwards binary compatibility * process in a lighter-weight, backwards-compatible way Author: Edoardo Vacchi Closes #6356 from evacchi/sqlctx-refactoring-lite. 
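The compatibility move described above, reduced to a generic sketch: the logic migrates to a top-level class, while the old nested type survives only as a deprecated shim that extends it, so callers written against the inner-class path keep compiling. The names below are illustrative, not Spark APIs:

```scala
// New home for the behaviour: an ordinary top-level class.
class Engine(val owner: Service) {
  def describe(): String = s"engine owned by ${owner.name}"
}

class Service(val name: String) {
  // The old inner type is kept only as a thin, deprecated alias for compatibility.
  @deprecated("use Engine directly", "1.6.0")
  class EngineShim extends Engine(this)

  def newEngine(): Engine = new Engine(this)
}

// Code written against the old nested type still works unchanged:
val svc = new Service("demo")
val viaShim: Engine = new svc.EngineShim
println(viaShim.describe())
```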
--- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 138 ++---------------- .../spark/sql/execution/QueryExecution.scala | 85 +++++++++++ .../spark/sql/execution/SQLExecution.scala | 2 +- .../spark/sql/execution/SparkPlanner.scala | 92 ++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 195 insertions(+), 128 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1a687b2374f14..3e61123c145cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -114,7 +114,7 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4e8414af50b44..e3fdd782e6ff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -38,6 +38,10 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} +import org.apache.spark.sql.execution.{Filter, _} +import org.apache.spark.sql.{execution => sparkexecution} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.sources._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} @@ -188,9 +192,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def parseSql(sql: String): LogicalPlan = ddlParser.parse(sql, false) - protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): + org.apache.spark.sql.execution.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) + protected[sql] def executePlan(plan: LogicalPlan) = + new sparkexecution.QueryExecution(this, plan) @transient protected[sql] val tlSession = new ThreadLocal[SQLSession]() { @@ -781,77 +787,11 @@ class SQLContext(@transient val sparkContext: SparkContext) }.toArray } - protected[sql] class SparkPlanner extends SparkStrategies { - val sparkContext: SparkContext = self.sparkContext - - val sqlContext: SQLContext = self - - def codegenEnabled: Boolean = self.conf.codegenEnabled - - def unsafeEnabled: Boolean = self.conf.unsafeEnabled - - def numPartitions: Int = self.conf.numShufflePartitions - - def strategies: Seq[Strategy] = - experimental.extraStrategies ++ ( - DataSourceStrategy :: - DDLStrategy :: - TakeOrderedAndProject :: - HashAggregation :: - Aggregation :: - LeftSemiJoin :: - EquiJoinSelection :: - InMemoryScans :: - BasicOperators :: - CartesianProduct :: - BroadcastNestedLoopJoin :: Nil) - - /** - * Used to build table scan operators where complex projection and filtering are done using - * separate physical operators. This function returns the given scan operator with Project and - * Filter nodes added only when needed. For example, a Project operator is only used when the - * final desired output requires complex expressions to be evaluated or when columns can be - * further eliminated out after filtering has been done. - * - * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized - * away by the filter pushdown optimization. - * - * The required attributes for both filtering and expression evaluation are passed to the - * provided `scanBuilder` function so that it can avoid unnecessary column materialization. 
- */ - def pruneFilterProject( - projectList: Seq[NamedExpression], - filterPredicates: Seq[Expression], - prunePushedDownFilters: Seq[Expression] => Seq[Expression], - scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { - - val projectSet = AttributeSet(projectList.flatMap(_.references)) - val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) - val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) - - // Right now we still use a projection even if the only evaluation is applying an alias - // to a column. Since this is a no-op, it could be avoided. However, using this - // optimization with the current implementation would change the output schema. - // TODO: Decouple final output schema from expression evaluation so this copy can be - // avoided safely. - - if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && - filterSet.subsetOf(projectSet)) { - // When it is possible to just use column pruning to get the right projection and - // when the columns of this projection are enough to evaluate all filter conditions, - // just do a scan followed by a filter, with no extra project. - val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) - filterCondition.map(Filter(_, scan)).getOrElse(scan) - } else { - val scan = scanBuilder((projectSet ++ filterSet).toSeq) - Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) - } - } - } + @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") + protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) @transient - protected[sql] val planner = new SparkPlanner + protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @transient protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) @@ -898,59 +838,9 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val conf: SQLConf = new SQLConf } - /** - * :: DeveloperApi :: - * The primary workflow for executing relational queries using Spark. Designed to allow easy - * access to the intermediate phases of query execution for developers. - */ - @DeveloperApi - protected[sql] class QueryExecution(val logical: LogicalPlan) { - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - - lazy val analyzed: LogicalPlan = analyzer.execute(logical) - lazy val withCachedData: LogicalPlan = { - assertAnalyzed() - cacheManager.useCachedData(analyzed) - } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) - - // TODO: Don't just pick the first one... - lazy val sparkPlan: SparkPlan = { - SparkPlan.currentContext.set(self) - planner.plan(optimizedPlan).next() - } - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) - - /** Internal version of the RDD. 
Avoids copies and has no schema */ - lazy val toRdd: RDD[InternalRow] = executedPlan.execute() - - protected def stringOrError[A](f: => A): String = - try f.toString catch { case e: Throwable => e.toString } - - def simpleString: String = - s"""== Physical Plan == - |${stringOrError(executedPlan)} - """.stripMargin.trim - - override def toString: String = { - def output = - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") - - s"""== Parsed Logical Plan == - |${stringOrError(logical)} - |== Analyzed Logical Plan == - |${stringOrError(output)} - |${stringOrError(analyzed)} - |== Optimized Logical Plan == - |${stringOrError(optimizedPlan)} - |== Physical Plan == - |${stringOrError(executedPlan)} - |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} - """.stripMargin.trim - } - } + @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") + protected[sql] class QueryExecution(logical: LogicalPlan) + extends sparkexecution.QueryExecution(this, logical) /** * Parses the data type in our internal string representation. The data type string should diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala new file mode 100644 index 0000000000000..7bb4133a29059 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.{Experimental, DeveloperApi} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{InternalRow, optimizer} +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * :: DeveloperApi :: + * The primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + */ +@DeveloperApi +class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + val analyzer = sqlContext.analyzer + val optimizer = sqlContext.optimizer + val planner = sqlContext.planner + val cacheManager = sqlContext.cacheManager + val prepareForExecution = sqlContext.prepareForExecution + + def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = analyzer.execute(logical) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed() + cacheManager.useCachedData(analyzed) + } + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) + + // TODO: Don't just pick the first one... 
+ lazy val sparkPlan: SparkPlan = { + SparkPlan.currentContext.set(sqlContext) + planner.plan(optimizedPlan).next() + } + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + + /** Internal version of the RDD. Avoids copies and has no schema */ + lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + + protected def stringOrError[A](f: => A): String = + try f.toString catch { case e: Throwable => e.toString } + + def simpleString: String = + s"""== Physical Plan == + |${stringOrError(executedPlan)} + """.stripMargin.trim + + + override def toString: String = { + def output = + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") + + s"""== Parsed Logical Plan == + |${stringOrError(logical)} + |== Analyzed Logical Plan == + |${stringOrError(output)} + |${stringOrError(analyzed)} + |== Optimized Logical Plan == + |${stringOrError(optimizedPlan)} + |== Physical Plan == + |${stringOrError(executedPlan)} + |Code Generation: ${stringOrError(executedPlan.codegenEnabled)} + """.stripMargin.trim + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index cee58218a885b..1422e15549c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -37,7 +37,7 @@ private[sql] object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sqlContext: SQLContext, queryExecution: SQLContext#QueryExecution)(body: => T): T = { + sqlContext: SQLContext, queryExecution: QueryExecution)(body: => T): T = { val sc = sqlContext.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) if (oldExecutionId == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala new file mode 100644 index 0000000000000..b346f43faebe2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.datasources.DataSourceStrategy + +@Experimental +class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { + val sparkContext: SparkContext = sqlContext.sparkContext + + def codegenEnabled: Boolean = sqlContext.conf.codegenEnabled + + def unsafeEnabled: Boolean = sqlContext.conf.unsafeEnabled + + def numPartitions: Int = sqlContext.conf.numShufflePartitions + + def strategies: Seq[Strategy] = + sqlContext.experimental.extraStrategies ++ ( + DataSourceStrategy :: + DDLStrategy :: + TakeOrderedAndProject :: + HashAggregation :: + Aggregation :: + LeftSemiJoin :: + EquiJoinSelection :: + InMemoryScans :: + BasicOperators :: + CartesianProduct :: + BroadcastNestedLoopJoin :: Nil) + + /** + * Used to build table scan operators where complex projection and filtering are done using + * separate physical operators. This function returns the given scan operator with Project and + * Filter nodes added only when needed. For example, a Project operator is only used when the + * final desired output requires complex expressions to be evaluated or when columns can be + * further eliminated out after filtering has been done. + * + * The `prunePushedDownFilters` parameter is used to remove those filters that can be optimized + * away by the filter pushdown optimization. + * + * The required attributes for both filtering and expression evaluation are passed to the + * provided `scanBuilder` function so that it can avoid unnecessary column materialization. + */ + def pruneFilterProject( + projectList: Seq[NamedExpression], + filterPredicates: Seq[Expression], + prunePushedDownFilters: Seq[Expression] => Seq[Expression], + scanBuilder: Seq[Attribute] => SparkPlan): SparkPlan = { + + val projectSet = AttributeSet(projectList.flatMap(_.references)) + val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) + val filterCondition = + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) + + // Right now we still use a projection even if the only evaluation is applying an alias + // to a column. Since this is a no-op, it could be avoided. However, using this + // optimization with the current implementation would change the output schema. + // TODO: Decouple final output schema from expression evaluation so this copy can be + // avoided safely. + + if (AttributeSet(projectList.map(_.toAttribute)) == projectSet && + filterSet.subsetOf(projectSet)) { + // When it is possible to just use column pruning to get the right projection and + // when the columns of this projection are enough to evaluate all filter conditions, + // just do a scan followed by a filter, with no extra project. 
+ val scan = scanBuilder(projectList.asInstanceOf[Seq[Attribute]]) + filterCondition.map(Filter(_, scan)).getOrElse(scan) + } else { + val scan = scanBuilder((projectSet ++ filterSet).toSeq) + Project(projectList, filterCondition.map(Filter(_, scan)).getOrElse(scan)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4572d5efc92bb..5e40d77689045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { - self: SQLContext#SparkPlanner => + self: SparkPlanner => object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { From 217e4964444f4e07b894b1bca768a0cbbe799ea0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 14 Sep 2015 15:00:27 -0700 Subject: [PATCH 278/802] [SPARK-9996] [SPARK-9997] [SQL] Add local expand and NestedLoopJoin operators This PR is in conflict with #8535 and #8573. Will update this one when they are merged. Author: zsxwing Closes #8642 from zsxwing/expand-nest-join. --- .../sql/execution/local/ExpandNode.scala | 60 +++++ .../spark/sql/execution/local/LocalNode.scala | 55 +++- .../execution/local/NestedLoopJoinNode.scala | 156 ++++++++++++ .../sql/execution/local/ExpandNodeSuite.scala | 51 ++++ .../execution/local/HashJoinNodeSuite.scala | 14 - .../sql/execution/local/LocalNodeTest.scala | 14 + .../local/NestedLoopJoinNodeSuite.scala | 239 ++++++++++++++++++ 7 files changed, 574 insertions(+), 15 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala new file mode 100644 index 0000000000000..2aff156d18b54 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -0,0 +1,60 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} + +case class ExpandNode( + conf: SQLConf, + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LocalNode) extends UnaryLocalNode(conf) { + + assert(projections.size > 0) + + private[this] var result: InternalRow = _ + private[this] var idx: Int = _ + private[this] var input: InternalRow = _ + private[this] var groups: Array[Projection] = _ + + override def open(): Unit = { + child.open() + groups = projections.map(ee => newProjection(ee, child.output)).toArray + idx = groups.length + } + + override def next(): Boolean = { + if (idx >= groups.length) { + if (child.next()) { + input = child.fetch() + idx = 0 + } else { + return false + } + } + result = groups(idx)(input) + idx += 1 + true + } + + override def fetch(): InternalRow = result + + override def close(): Unit = child.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index e540ef8555eb6..9840080e16953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -23,7 +23,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.StructType @@ -69,6 +69,18 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging */ def close(): Unit + /** Specifies whether this operator outputs UnsafeRows */ + def outputsUnsafeRows: Boolean = false + + /** Specifies whether this operator is capable of processing UnsafeRows */ + def canProcessUnsafeRows: Boolean = false + + /** + * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows + * that are not UnsafeRows). + */ + def canProcessSafeRows: Boolean = true + /** * Returns the content through the [[Iterator]] interface. 
*/ @@ -91,6 +103,28 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging result } + protected def newProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + try { + GenerateProjection.generate(expressions, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate projection, fallback to interpret", e) + new InterpretedProjection(expressions, inputSchema) + } + } + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { @@ -113,6 +147,25 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging } } + protected def newPredicate( + expression: Expression, + inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { + if (codegenEnabled) { + try { + GeneratePredicate.generate(expression, inputSchema) + } catch { + case NonFatal(e) => + if (isTesting) { + throw e + } else { + log.error("Failed to generate predicate, fallback to interpreted", e) + InterpretedPredicate.create(expression, inputSchema) + } + } + } else { + InterpretedPredicate.create(expression, inputSchema) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala new file mode 100644 index 0000000000000..7321fc66b4dde --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.util.collection.{BitSet, CompactBuffer} + +case class NestedLoopJoinNode( + conf: SQLConf, + left: LocalNode, + right: LocalNode, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) extends BinaryLocalNode(conf) { + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new IllegalArgumentException( + s"NestedLoopJoin should not take $x as the JoinType") + } + } + + private[this] def genResultProjection: InternalRow => InternalRow = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + identity[InternalRow] + } + } + + private[this] var currentRow: InternalRow = _ + + private[this] var iterator: Iterator[InternalRow] = _ + + override def open(): Unit = { + val (streamed, build) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + build.open() + val buildRelation = new CompactBuffer[InternalRow] + while (build.next()) { + buildRelation += build.fetch().copy() + } + build.close() + + val boundCondition = + newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + + val leftNulls = new GenericMutableRow(left.output.size) + val rightNulls = new GenericMutableRow(right.output.size) + val joinedRow = new JoinedRow + val matchedBuildTuples = new BitSet(buildRelation.size) + val resultProj = genResultProjection + streamed.open() + + // streamedRowMatches also contains null rows if using outer join + val streamedRowMatches: Iterator[InternalRow] = streamed.asIterator.flatMap { streamedRow => + val matchedRows = new CompactBuffer[InternalRow] + + var i = 0 + var streamRowMatched = false + + // Scan the build relation to look for matches for each streamed row + while (i < buildRelation.size) { + val buildRow = buildRelation(i) + buildSide match { + case BuildRight => joinedRow(streamedRow, buildRow) + case BuildLeft => joinedRow(buildRow, streamedRow) + } + if (boundCondition(joinedRow)) { + matchedRows += resultProj(joinedRow).copy() + streamRowMatched = true + matchedBuildTuples.set(i) + } + i += 1 + } + + // If this row had no matches and we're using outer join, join it with the null rows + if (!streamRowMatched) { + (joinType, buildSide) match { + case (LeftOuter | FullOuter, BuildRight) => + matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy() + case (RightOuter | FullOuter, BuildLeft) => + matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy() + case _ => + } + } + + matchedRows.iterator + } + + // If we're using outer join, find rows on the build side that didn't match anything + // and join them with the null row + lazy val unmatchedBuildRows: Iterator[InternalRow] = { + var i = 0 + buildRelation.filter { row => + val r = !matchedBuildTuples.get(i) + i += 1 + r + }.iterator + } + iterator = (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + 
streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(leftNulls, buildRow)) } + case (LeftOuter | FullOuter, BuildLeft) => + streamedRowMatches ++ + unmatchedBuildRows.map { buildRow => resultProj(joinedRow(buildRow, rightNulls)) } + case _ => streamedRowMatches + } + } + + override def next(): Boolean = { + if (iterator.hasNext) { + currentRow = iterator.next() + true + } else { + false + } + } + + override def fetch(): InternalRow = currentRow + + override def close(): Unit = { + left.close() + right.close() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala new file mode 100644 index 0000000000000..cfa7f3f6dcb97 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -0,0 +1,51 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.local + +class ExpandNodeSuite extends LocalNodeTest { + + import testImplicits._ + + test("expand") { + val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") + checkAnswer( + input, + node => + ExpandNode(conf, Seq( + Seq( + input.col("key") + input.col("value"), input.col("key") - input.col("value") + ).map(_.expr), + Seq( + input.col("key") * input.col("value"), input.col("key") / input.col("value") + ).map(_.expr) + ), node.output, node), + Seq( + (2, 0), + (1, 1), + (4, 0), + (4, 1), + (6, 0), + (9, 1), + (8, 0), + (16, 1), + (10, 0), + (25, 1) + ).toDF().collect() + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 43b6f06aead88..78d891351f4a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -24,20 +24,6 @@ class HashJoinNodeSuite extends LocalNodeTest { import testImplicits._ - private def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { test(s"$suiteName: inner join with one match per row") { withSQLConf(confPairs: _*) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index b95d4ea7f8f2a..86dd28064cc6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -27,6 +27,20 @@ class LocalNodeTest extends SparkFunSuite with SharedSQLContext { def conf: SQLConf = sqlContext.conf + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + if (conf.unsafeEnabled) { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) + } + } else { + f + } + } + /** * Runs the LocalNode and makes sure the answer matches the expected result. * @param input the input data to be used. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala new file mode 100644 index 0000000000000..b1ef26ba82f16 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -0,0 +1,239 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + +class NestedLoopJoinNodeSuite extends LocalNodeTest { + + import testImplicits._ + + private def joinSuite( + suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { + test(s"$suiteName: left outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("n") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + upperCaseData.col("N") > 1).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) + + checkAnswer2( + upperCaseData, + lowerCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + LeftOuter, + Some( + (upperCaseData.col("N") === lowerCaseData.col("n") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect()) + } + } + + test(s"$suiteName: right outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + RightOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + } + } + + test(s"$suiteName: full outer join") { + withSQLConf(confPairs: _*) { + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, 
+ buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("n") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + upperCaseData.col("N") > 1).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) + + checkAnswer2( + lowerCaseData, + upperCaseData, + wrapForUnsafe( + (node1, node2) => NestedLoopJoinNode( + conf, + node1, + node2, + buildSide, + FullOuter, + Some((lowerCaseData.col("n") === upperCaseData.col("N") && + lowerCaseData.col("l") > upperCaseData.col("L")).expr)) + ), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) + } + } + } + + joinSuite( + "general-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "general-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") + joinSuite( + "tungsten-build-left", + BuildLeft, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") + joinSuite( + "tungsten-build-right", + BuildRight, + SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") +} From 16b6d18613e150c7038c613992d80a7828413e66 Mon Sep 17 00:00:00 2001 From: Erick Tryzelaar Date: Mon, 14 Sep 2015 15:02:38 -0700 Subject: [PATCH 279/802] [SPARK-10594] [YARN] Remove reference to --num-executors, add --properties-file `ApplicationMaster` no longer has the `--num-executors` flag, and had an undocumented `--properties-file` configuration option. cc srowen Author: Erick Tryzelaar Closes #8754 from erickt/master. --- .../apache/spark/deploy/yarn/ApplicationMasterArguments.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index b08412414aa1c..17d9943c795e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -105,9 +105,9 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) // scalastyle:on println System.exit(exitCode) From 4e2242bb41dda922573046c00c5142745632f95f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 14 Sep 2015 15:03:51 -0700 Subject: [PATCH 280/802] [SPARK-10576] [BUILD] Move .java files out of src/main/scala Move .java files in `src/main/scala` to `src/main/java` root, except for `package-info.java` (to stay next to package.scala) Author: Sean Owen Closes #8736 from srowen/SPARK-10576. --- .../org/apache/spark/annotation/AlphaComponent.java | 0 .../{scala => java}/org/apache/spark/annotation/DeveloperApi.java | 0 .../{scala => java}/org/apache/spark/annotation/Experimental.java | 0 .../main/{scala => java}/org/apache/spark/annotation/Private.java | 0 .../{scala => java}/org/apache/spark/graphx/TripletFields.java | 0 .../org/apache/spark/graphx/impl/EdgeActiveness.java | 0 .../org/apache/spark/sql/types/SQLUserDefinedType.java | 0 .../org/apache/spark/streaming/StreamingContextState.java | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename core/src/main/{scala => java}/org/apache/spark/annotation/AlphaComponent.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/DeveloperApi.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Experimental.java (100%) rename core/src/main/{scala => java}/org/apache/spark/annotation/Private.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/TripletFields.java (100%) rename graphx/src/main/{scala => java}/org/apache/spark/graphx/impl/EdgeActiveness.java (100%) rename sql/catalyst/src/main/{scala => java}/org/apache/spark/sql/types/SQLUserDefinedType.java (100%) rename streaming/src/main/{scala => java}/org/apache/spark/streaming/StreamingContextState.java (100%) diff --git a/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java b/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java rename to core/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java rename to core/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Experimental.java rename to core/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Private.java b/core/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Private.java rename to core/src/main/java/org/apache/spark/annotation/Private.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/java/org/apache/spark/graphx/TripletFields.java similarity index 100% rename from graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java rename to graphx/src/main/java/org/apache/spark/graphx/TripletFields.java diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java similarity index 100% rename from 
graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java rename to graphx/src/main/java/org/apache/spark/graphx/impl/EdgeActiveness.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java b/streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java similarity index 100% rename from streaming/src/main/scala/org/apache/spark/streaming/StreamingContextState.java rename to streaming/src/main/java/org/apache/spark/streaming/StreamingContextState.java From ffbbc2c58b9bf1e2abc2ea797feada6821ab4de8 Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Mon, 14 Sep 2015 15:05:19 -0700 Subject: [PATCH 281/802] [SPARK-10549] scala 2.11 spark on yarn with security - Repl doesn't work Make this lazy so that it can set the yarn mode before creating the securityManager. Author: Tom Graves Author: Thomas Graves Closes #8719 from tgravescs/SPARK-10549. --- .../scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index be31eb2eda546..627148df80c11 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -35,7 +35,8 @@ object Main extends Logging { s.processArguments(List("-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-classpath", getAddedJars.mkString(File.pathSeparator)), true) - val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) + // the creation of SecurityManager has to be lazy so SPARK_YARN_MODE is set if needed + lazy val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf)) var sparkContext: SparkContext = _ var sqlContext: SQLContext = _ var interp = new SparkILoop // this is a public var because tests reset it. From fd1e8cddf2635c55fec2ac6e1f1c221c9685af0f Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Mon, 14 Sep 2015 15:07:13 -0700 Subject: [PATCH 282/802] [SPARK-10543] [CORE] Peak Execution Memory Quantile should be Per-task Basis Read `PEAK_EXECUTION_MEMORY` using `update` to get per task partial value instead of cumulative value. I tested with this workload: ```scala val size = 1000 val repetitions = 10 val data = sc.parallelize(1 to size, 5).map(x => (util.Random.nextInt(size / repetitions),util.Random.nextDouble)).toDF("key", "value") val res = data.toDF.groupBy("key").agg(sum("value")).count ``` Before: ![image](https://cloud.githubusercontent.com/assets/4317392/9828197/07dd6874-58b8-11e5-9bd9-6ba927c38b26.png) After: ![image](https://cloud.githubusercontent.com/assets/4317392/9828151/a5ddff30-58b7-11e5-8d31-eda5dc4eae79.png) Tasks view: ![image](https://cloud.githubusercontent.com/assets/4317392/9828199/17dc2b84-58b8-11e5-92a8-be89ce4d29d1.png) cc andrewor14 I appreciate if you can give feedback on this since I think you introduced display of this metric. Author: Forest Fang Closes #8726 from saurfang/stagepage. 
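The distinction the fix turns on, with made-up numbers: an accumulable's value is the running total across all finished tasks, while its update is the single task's own contribution, so per-task quantiles have to be computed from update. A self-contained sketch (AccInfo stands in for Spark's AccumulableInfo):

```scala
// Two tasks that each peak at 10 bytes of execution memory.
case class AccInfo(update: Option[String], value: String)

val task1 = AccInfo(update = Some("10"), value = "10") // running total after task 1
val task2 = AccInfo(update = Some("10"), value = "20") // running total after task 2

// Quantiles over `value` would see 10 and 20 -- cumulative and misleading.
// Quantiles over `update` see 10 and 10 -- the true per-task peaks.
val perTaskPeaks = Seq(task1, task2).map(_.update.getOrElse("0").toLong)
assert(perTaskPeaks == Seq(10L, 10L))
```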
--- .../org/apache/spark/ui/jobs/StagePage.scala | 2 +- .../org/apache/spark/ui/StagePageSuite.scala | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 4adc6596ba21c..2b71f55b7bb4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -368,7 +368,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) => info.accumulables .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY } - .map { acc => acc.value.toLong } + .map { acc => acc.update.getOrElse("0").toLong } .getOrElse(0L) .toDouble } diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index 3388c6dca81f1..86699e7f56953 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -23,7 +23,7 @@ import scala.xml.Node import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite, Success} +import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.ui.jobs.{JobProgressListener, StagePage, StagesTab} @@ -47,6 +47,14 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { assert(html3.contains(targetString)) } + test("SPARK-10543: peak execution memory should be per-task rather than cumulative") { + val unsafeConf = "spark.sql.unsafe.enabled" + val conf = new SparkConf(false).set(unsafeConf, "true") + val html = renderStagePage(conf).toString().toLowerCase + // verify min/25/50/75/max show task value not cumulative values + assert(html.contains("10.0 b" * 5)) + } + /** * Render a stage page started with the given conf and return the HTML. * This also runs a dummy stage to populate the page with useful content. 
@@ -67,12 +75,19 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { // Simulate a stage in job progress listener val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") - val taskInfo = new TaskInfo(0, 0, 0, 0, "0", "localhost", TaskLocality.ANY, false) - jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) - jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markSuccessful() - jobListener.onTaskEnd( - SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness + (1 to 2).foreach { + taskId => + val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) + val peakExecutionMemory = 10 + taskInfo.accumulables += new AccumulableInfo(0, InternalAccumulator.PEAK_EXECUTION_MEMORY, + Some(peakExecutionMemory.toString), (peakExecutionMemory * taskId).toString, true) + jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) + jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) + taskInfo.markSuccessful() + jobListener.onTaskEnd( + SparkListenerTaskEnd(0, 0, "result", Success, taskInfo, TaskMetrics.empty)) + } jobListener.onStageCompleted(SparkListenerStageCompleted(stageInfo)) page.render(request) } From 7b6c856367b9c36348e80e83959150da9656c4dd Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 14 Sep 2015 15:09:43 -0700 Subject: [PATCH 283/802] [SPARK-10564] ThreadingSuite: assertion failures in threads don't fail the test (round 2) This is a follow-up patch to #8723. I missed one case there. Author: Andrew Or Closes #8727 from andrewor14/fix-threading-suite. --- .../org/apache/spark/ThreadingSuite.scala | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index cda2b245526f7..a96a4ce201c21 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,12 +147,12 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) + throwable.foreach { t => throw t } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") fail("One or more threads didn't see runningThreads = 4") } - throwable.foreach { t => throw t } } test("set local properties in different thread") { @@ -178,8 +178,8 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - assert(sc.getLocalProperty("test") === null) throwable.foreach { t => throw t } + assert(sc.getLocalProperty("test") === null) } test("set and get local properties in parent-children thread") { @@ -207,15 +207,16 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) + throwable.foreach { t => throw t } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) - throwable.foreach { t => throw t } } test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { val jobStarted = new Semaphore(0) val jobEnded = new Semaphore(0) @volatile var jobResult: JobResult = null + var throwable: Option[Throwable] = None sc = new SparkContext("local", "test") 
sc.setJobGroup("originalJobGroupId", "description") @@ -232,14 +233,19 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // Create a new thread which will inherit the current thread's properties val thread = new Thread() { override def run(): Unit = { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") + // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task + try { + sc.parallelize(1 to 100).foreach { x => + Thread.sleep(100) + } + } catch { + case s: SparkException => // ignored so that we don't print noise in test logs } } catch { - case s: SparkException => // ignored so that we don't print noise in test logs + case t: Throwable => + throwable = Some(t) } } } @@ -252,6 +258,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { // modification of the properties object should not affect the properties of running jobs sc.cancelJobGroup("originalJobGroupId") jobEnded.tryAcquire(10, TimeUnit.SECONDS) + throwable.foreach { t => throw t } assert(jobResult.isInstanceOf[JobFailed]) } } From 1a0955250bb65cd6f5818ad60efb62ea4b45d18e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 14 Sep 2015 21:47:40 -0400 Subject: [PATCH 284/802] [SPARK-9851] Support submitting map stages individually in DAGScheduler This patch adds support for submitting map stages in a DAG individually so that we can make downstream decisions after seeing statistics about their output, as part of SPARK-9850. I also added more comments to many of the key classes in DAGScheduler. By itself, the patch is not super useful except maybe to switch between a shuffle and broadcast join, but with the other subtasks of SPARK-9850 we'll be able to do more interesting decisions. The main entry point is SparkContext.submitMapStage, which lets you run a map stage and see stats about the map output sizes. Other stats could also be collected through accumulators. See AdaptiveSchedulingSuite for a short example. Author: Matei Zaharia Closes #8180 from mateiz/spark-9851. 
--- .../apache/spark/MapOutputStatistics.scala | 27 ++ .../org/apache/spark/MapOutputTracker.scala | 49 +++- .../scala/org/apache/spark/SparkContext.scala | 17 ++ .../apache/spark/scheduler/ActiveJob.scala | 34 ++- .../apache/spark/scheduler/DAGScheduler.scala | 251 +++++++++++++++--- .../spark/scheduler/DAGSchedulerEvent.scala | 10 + .../apache/spark/scheduler/ResultStage.scala | 17 +- .../spark/scheduler/ShuffleMapStage.scala | 13 +- .../org/apache/spark/scheduler/Stage.scala | 26 +- .../scala/org/apache/spark/FailureSuite.scala | 21 ++ .../scheduler/AdaptiveSchedulingSuite.scala | 65 +++++ .../spark/scheduler/DAGSchedulerSuite.scala | 243 ++++++++++++++++- 12 files changed, 710 insertions(+), 63 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/MapOutputStatistics.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala new file mode 100644 index 0000000000000..f8a6f1d0d8cbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/MapOutputStatistics.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Holds statistics about the output sizes in a map stage. May become a DeveloperApi in the future. + * + * @param shuffleId ID of the shuffle + * @param bytesByPartitionId approximate number of output bytes for each map output partition + * (may be inexact due to use of compressed map statuses) + */ +private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long]) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index a387592783850..94eb8daa85c53 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io._ +import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -132,13 +133,43 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. 
*/ def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") - val startTime = System.currentTimeMillis + val statuses = getStatuses(shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) + } + } + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + val statuses = getStatuses(dep.shuffleId) + // Synchronize on the returned array because, on the driver, it gets mutated in place + statuses.synchronized { + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis var fetchedStatuses: Array[MapStatus] = null fetching.synchronized { // Someone else is fetching it; wait for them to be done @@ -160,7 +191,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } if (fetchedStatuses == null) { - // We won the race to fetch the output locs; do so + // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { @@ -175,22 +206,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } - logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + return fetchedStatuses } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + return statuses } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e27b3c4962221..dee6091ce3caf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1984,6 +1984,23 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new SimpleFutureAction(waiter, resultFunc) } + /** + * Submit a map stage for execution. 
This is currently an internal API only, but might be + * promoted to DeveloperApi in the future. + */ + private[spark] def submitMapStage[K, V, C](dependency: ShuffleDependency[K, V, C]) + : SimpleFutureAction[MapOutputStatistics] = { + assertNotStopped() + val callSite = getCallSite() + var result: MapOutputStatistics = null + val waiter = dagScheduler.submitMapStage( + dependency, + (r: MapOutputStatistics) => { result = r }, + callSite, + localProperties.get) + new SimpleFutureAction[MapOutputStatistics](waiter, result) + } + /** * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] * for more information. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 50a69379412d2..a3d2db31301b3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -23,18 +23,42 @@ import org.apache.spark.TaskContext import org.apache.spark.util.CallSite /** - * Tracks information about an active job in the DAGScheduler. + * A running job in the DAGScheduler. Jobs can be of two types: a result job, which computes a + * ResultStage to execute an action, or a map-stage job, which computes the map outputs for a + * ShuffleMapStage before any downstream stages are submitted. The latter is used for adaptive + * query planning, to look at map output statistics before submitting later stages. We distinguish + * between these two types of jobs using the finalStage field of this class. + * + * Jobs are only tracked for "leaf" stages that clients directly submitted, through DAGScheduler's + * submitJob or submitMapStage methods. However, either type of job may cause the execution of + * other earlier stages (for RDDs in the DAG it depends on), and multiple jobs may share some of + * these previous stages. These dependencies are managed inside DAGScheduler. + * + * @param jobId A unique ID for this job. + * @param finalStage The stage that this job computes (either a ResultStage for an action or a + * ShuffleMapStage for submitMapStage). + * @param callSite Where this job was initiated in the user's program (shown on UI). + * @param listener A listener to notify if tasks in this job finish or the job fails. + * @param properties Scheduling properties attached to the job, such as fair scheduler pool name. */ private[spark] class ActiveJob( val jobId: Int, - val finalStage: ResultStage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], + val finalStage: Stage, val callSite: CallSite, val listener: JobListener, val properties: Properties) { - val numPartitions = partitions.length + /** + * Number of partitions we need to compute for this job. Note that result stages may not need + * to compute all partitions in their target RDD, for actions like first() and lookup(). 
+ */ + val numPartitions = finalStage match { + case r: ResultStage => r.partitions.length + case m: ShuffleMapStage => m.rdd.partitions.length + } + + /** Which partitions of the stage have finished */ val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09e963f5cdf60..b4f90e8347894 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -45,17 +45,65 @@ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a * minimal schedule to run the job. It then submits stages as TaskSets to an underlying - * TaskScheduler implementation that runs them on the cluster. + * TaskScheduler implementation that runs them on the cluster. A TaskSet contains fully independent + * tasks that can run right away based on the data that's already on the cluster (e.g. map output + * files from previous stages), though it may fail if this data becomes unavailable. * - * In addition to coming up with a DAG of stages, this class also determines the preferred + * Spark stages are created by breaking the RDD graph at shuffle boundaries. RDD operations with + * "narrow" dependencies, like map() and filter(), are pipelined together into one set of tasks + * in each stage, but operations with shuffle dependencies require multiple stages (one to write a + * set of map output files, and another to read those files after a barrier). In the end, every + * stage will have only shuffle dependencies on other stages, and may compute multiple operations + * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of + * various RDDs (MappedRDD, FilteredRDD, etc). + * + * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * + * When looking through this code, there are several key concepts: + * + * - Jobs (represented by [[ActiveJob]]) are the top-level work items submitted to the scheduler. + * For example, when the user calls an action, like count(), a job will be submitted through + * submitJob. Each Job may require the execution of multiple stages to build intermediate data. + * + * - Stages ([[Stage]]) are sets of tasks that compute intermediate results in jobs, where each + * task computes the same function on partitions of the same RDD. Stages are separated at shuffle + * boundaries, which introduce a barrier (where we must wait for the previous stage to finish to + * fetch outputs). There are two types of stages: [[ResultStage]], for the final stage that + * executes an action, and [[ShuffleMapStage]], which writes map output files for a shuffle. + * Stages are often shared across multiple jobs, if these jobs reuse the same RDDs. 
+ * + * - Tasks are individual units of work, each sent to one machine. + * + * - Cache tracking: the DAGScheduler figures out which RDDs are cached to avoid recomputing them + * and likewise remembers which shuffle map stages have already produced output files to avoid + * redoing the map side of a shuffle. + * + * - Preferred locations: the DAGScheduler also computes where to run each task in a stage based + * on the preferred locations of its underlying RDDs, or the location of cached or shuffle data. + * + * - Cleanup: all data structures are cleared when the running jobs that depend on them finish, + * to prevent memory leaks in a long-running application. + * + * To recover from failures, the same stage might need to run multiple times, which are called + * "attempts". If the TaskScheduler reports that a task failed because a map output file from a + * previous stage was lost, the DAGScheduler resubmits that lost stage. This is detected through a + * CompletionEvent with FetchFailed, or an ExecutorLost event. The DAGScheduler will wait a small + * amount of time to see whether other nodes or tasks fail, then resubmit TaskSets for any lost + * stage(s) that compute the missing tasks. As part of this process, we might also have to create + * Stage objects for old (finished) stages where we previously cleaned up the Stage object. Since + * tasks from the old attempt of a stage could still be running, care must be taken to map any + * events received in the correct Stage object. + * * Here's a checklist to use when making or reviewing changes to this class: * + * - All data structures should be cleared when the jobs involving them end to avoid indefinite + * accumulation of state in long-running programs. + * * - When adding a new data structure, update `DAGSchedulerSuite.assertDataStructuresEmpty` to * include the new structure. This will help to catch memory leaks. */ @@ -295,12 +343,12 @@ class DAGScheduler( */ private def newResultStage( rdd: RDD[_], - numTasks: Int, + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], jobId: Int, callSite: CallSite): ResultStage = { val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) - val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite) - + val stage = new ResultStage(id, rdd, func, partitions, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage @@ -500,12 +548,25 @@ class DAGScheduler( jobIdToStageIds -= job.jobId jobIdToActiveJob -= job.jobId activeJobs -= job - job.finalStage.resultOfJob = None + job.finalStage match { + case r: ResultStage => + r.resultOfJob = None + case m: ShuffleMapStage => + m.mapStageJobs = m.mapStageJobs.filter(_ != job) + } } /** - * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name */ def submitJob[T, U]( rdd: RDD[T], @@ -524,6 +585,7 @@ class DAGScheduler( val jobId = nextJobId.getAndIncrement() if (partitions.size == 0) { + // Return immediately if the job is running 0 tasks return new JobWaiter[U](this, jobId, 0, resultHandler) } @@ -536,6 +598,18 @@ class DAGScheduler( waiter } + /** + * Run an action job on the given RDD and pass all the results to the resultHandler function as + * they arrive. Throws an exception if the job fials, or returns normally if successful. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like first() + * @param callSite where in the user program this job was called + * @param resultHandler callback to pass each result to + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runJob[T, U]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -559,6 +633,17 @@ class DAGScheduler( } } + /** + * Run an approximate job on the given RDD and pass all the results to an ApproximateEvaluator + * as they arrive. Returns a partial result object from the evaluator. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param callSite where in the user program this job was called + * @param timeout maximum time to wait for the job, in milliseconds + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -575,6 +660,41 @@ class DAGScheduler( listener.awaitResult() // Will throw an exception if the job fails } + /** + * Submit a shuffle map stage to run independently and get a JobWaiter object back. The waiter + * can be used to block until the the job finishes executing or can be used to cancel the job. + * This method is used for adaptive query planning, to run map stages and look at statistics + * about their outputs before submitting downstream stages. + * + * @param dependency the ShuffleDependency to run a map stage for + * @param callback function called with the result of the job, which in this case will be a + * single MapOutputStatistics object showing how much data was produced for each partition + * @param callSite where in the user program this job was submitted + * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + */ + def submitMapStage[K, V, C]( + dependency: ShuffleDependency[K, V, C], + callback: MapOutputStatistics => Unit, + callSite: CallSite, + properties: Properties): JobWaiter[MapOutputStatistics] = { + + val rdd = dependency.rdd + val jobId = nextJobId.getAndIncrement() + if (rdd.partitions.length == 0) { + throw new SparkException("Can't run submitMapStage on RDD with 0 partitions") + } + + // We create a JobWaiter with only one "task", which will be marked as complete when the whole + // map stage has completed, and will be passed the MapOutputStatistics for that stage. 
+ // This makes it easier to avoid race conditions between the user code and the map output + // tracker that might result if we told the user the stage had finished, but then they queries + // the map output tracker and some node failures had caused the output statistics to be lost. + val waiter = new JobWaiter(this, jobId, 1, (i: Int, r: MapOutputStatistics) => callback(r)) + eventProcessLoop.post(MapStageSubmitted( + jobId, dependency, callSite, waiter, SerializationUtils.clone(properties))) + waiter + } + /** * Cancel a job that is running or waiting in the queue. */ @@ -583,6 +703,9 @@ class DAGScheduler( eventProcessLoop.post(JobCancelled(jobId)) } + /** + * Cancel all jobs in the given job group ID. + */ def cancelJobGroup(groupId: String): Unit = { logInfo("Asked to cancel job group " + groupId) eventProcessLoop.post(JobGroupCancelled(groupId)) @@ -720,31 +843,77 @@ class DAGScheduler( try { // New stage creation may throw an exception if, for example, jobs are run on a // HadoopRDD whose underlying HDFS files have been deleted. - finalStage = newResultStage(finalRDD, partitions.length, jobId, callSite) + finalStage = newResultStage(finalRDD, func, partitions, jobId, callSite) } catch { case e: Exception => logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return } - if (finalStage != null) { - val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) - clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions".format( - job.jobId, callSite.shortForm, partitions.length)) - logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val jobSubmissionTime = clock.getTimeMillis() - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + submitWaitingStages() + } + + private[scheduler] def handleMapStageSubmitted(jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties) { + // Submitting this map stage might still require the creation of some parent stages, so make + // sure that happens. + var finalStage: ShuffleMapStage = null + try { + // New stage creation may throw an exception if, for example, jobs are run on a + // HadoopRDD whose underlying HDFS files have been deleted. 
+ finalStage = getShuffleMapStage(dependency, jobId) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return + } + + val job = new ActiveJob(jobId, finalStage, callSite, listener, properties) + clearCacheLocs() + logInfo("Got map stage job %s (%s) with %d output partitions".format( + jobId, callSite.shortForm, dependency.rdd.partitions.size)) + logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")") + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + val jobSubmissionTime = clock.getTimeMillis() + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.mapStageJobs = job :: finalStage.mapStageJobs + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) + + // If the whole stage has already finished, tell the listener and remove it + if (!finalStage.outputLocs.contains(Nil)) { + markMapStageJobAsFinished(job, mapOutputTracker.getStatistics(dependency)) } + submitWaitingStages() } @@ -814,7 +983,7 @@ class DAGScheduler( case s: ResultStage => val job = s.resultOfJob.get partitionsToCompute.map { id => - val p = job.partitions(id) + val p = s.partitions(id) (id, getPreferredLocs(stage.rdd, p)) }.toMap } @@ -844,7 +1013,7 @@ class DAGScheduler( case stage: ShuffleMapStage => closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array() case stage: ResultStage => - closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array() + closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) @@ -875,7 +1044,7 @@ class DAGScheduler( case stage: ResultStage => val job = stage.resultOfJob.get partitionsToCompute.map { id => - val p: Int = job.partitions(id) + val p: Int = stage.partitions(id) val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, @@ -1052,13 +1221,21 @@ class DAGScheduler( logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + shuffleStage.outputLocs.zipWithIndex.filter(_._1.isEmpty) - .map(_._2).mkString(", ")) + .map(_._2).mkString(", ")) submitStage(shuffleStage) + } else { + // Mark any map-stage jobs waiting on this stage as finished + if (shuffleStage.mapStageJobs.nonEmpty) { + val stats = mapOutputTracker.getStatistics(shuffleStage.shuffleDep) + for (job <- shuffleStage.mapStageJobs) { + markMapStageJobAsFinished(job, stats) + } + } } // Note: newly runnable stages will be submitted below when we submit waiting stages } - } + } case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") @@ -1412,6 +1589,17 @@ class DAGScheduler( Nil } + /** Mark a map stage job as finished with the given output stats, and report to its listener. 
*/ + def markMapStageJobAsFinished(job: ActiveJob, stats: MapOutputStatistics): Unit = { + // In map stage jobs, we only create a single "task", which is to finish all of the stage + // (including reusing any previous map outputs, etc); so we just mark task 0 as done + job.finished(0) = true + job.numFinished += 1 + job.listener.taskSucceeded(0, stats) + cleanupStateForJobAndIndependentStages(job) + listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded)) + } + def stop() { logInfo("Stopping DAGScheduler") messageScheduler.shutdownNow() @@ -1445,6 +1633,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) + case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => + dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) + case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index f72a52e85dc15..dda3b6cc7f960 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -35,6 +35,7 @@ import org.apache.spark.util.CallSite */ private[scheduler] sealed trait DAGSchedulerEvent +/** A result-yielding job was submitted on a target RDD */ private[scheduler] case class JobSubmitted( jobId: Int, finalRDD: RDD[_], @@ -45,6 +46,15 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +/** A map stage as submitted to run as a separate job */ +private[scheduler] case class MapStageSubmitted( + jobId: Int, + dependency: ShuffleDependency[_, _, _], + callSite: CallSite, + listener: JobListener, + properties: Properties = null) + extends DAGSchedulerEvent + private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index bf81b9aca4810..c0451da1f0247 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -17,23 +17,30 @@ package org.apache.spark.scheduler +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * The ResultStage represents the final stage in a job. + * ResultStages apply a function on some partitions of an RDD to compute the result of an action. + * The ResultStage object captures the function to execute, `func`, which will be applied to each + * partition, and the set of partition IDs, `partitions`. Some stages may not run on all partitions + * of the RDD, for actions like first() and lookup(). 
*/ private[spark] class ResultStage( id: Int, rdd: RDD[_], - numTasks: Int, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], parents: List[Stage], firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { - // The active job for this result stage. Will be empty if the job has already finished - // (e.g., because the job was cancelled). + /** + * The active job for this result stage. Will be empty if the job has already finished + * (e.g., because the job was cancelled). + */ var resultOfJob: Option[ActiveJob] = None override def toString: String = "ResultStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 48d8d8e9c4b78..7d92960876403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -23,7 +23,15 @@ import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** - * The ShuffleMapStage represents the intermediate stages in a job. + * ShuffleMapStages are intermediate stages in the execution DAG that produce data for a shuffle. + * They occur right before each shuffle operation, and might contain multiple pipelined operations + * before that (e.g. map and filter). When executed, they save map output files that can later be + * fetched by reduce tasks. The `shuffleDep` field describes the shuffle each stage is part of, + * and variables like `outputLocs` and `numAvailableOutputs` track how many map outputs are ready. + * + * ShuffleMapStages can also be submitted independently as jobs with DAGScheduler.submitMapStage. + * For such stages, the ActiveJobs that submitted them are tracked in `mapStageJobs`. Note that + * there can be multiple ActiveJobs trying to compute the same shuffle map stage. */ private[spark] class ShuffleMapStage( id: Int, @@ -37,6 +45,9 @@ private[spark] class ShuffleMapStage( override def toString: String = "ShuffleMapStage " + id + /** Running map-stage jobs that were submitted to execute this stage independently (if any) */ + var mapStageJobs: List[ActiveJob] = Nil + var numAvailableOutputs: Int = 0 def isAvailable: Boolean = numAvailableOutputs == numPartitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index c086535782c23..b37eccbd0f7b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -24,27 +24,33 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite /** - * A stage is a set of independent tasks all computing the same function that need to run as part + * A stage is a set of parallel tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the * DAGScheduler runs these stages in topological order. * * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). 
For shuffle map stages, we also track the nodes - * that each output partition is on. + * other stage(s), or a result stage, in which case its tasks directly compute a Spark action + * (e.g. count(), save(), etc) by running a function on an RDD. For shuffle map stages, we also + * track the nodes that each output partition is on. * * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * - * The callSite provides a location in user code which relates to the stage. For a shuffle map - * stage, the callSite gives the user code that created the RDD being shuffled. For a result - * stage, the callSite gives the user code that executes the associated action (e.g. count()). - * - * A single stage can consist of multiple attempts. In that case, the latestInfo field will - * be updated for each attempt. + * Finally, a single stage can be re-executed in multiple attempts due to fault recovery. In that + * case, the Stage object will track multiple StageInfo objects to pass to listeners or the web UI. + * The latest one will be accessible through latestInfo. * + * @param id Unique stage ID + * @param rdd RDD that this stage runs on: for a shuffle map stage, it's the RDD we run map tasks + * on, while for a result stage, it's the target RDD that we ran an action on + * @param numTasks Total number of tasks in stage; result stages in particular may not need to + * compute all partitions, e.g. for first(), lookup(), and take(). + * @param parents List of stages that this stage depends on (through shuffle dependencies). + * @param firstJobId ID of the first job this stage was part of, for FIFO scheduling. + * @param callSite Location in the user program associated with this stage: either where the target + * RDD was created, for a shuffle map stage, or where the action for a result stage was called. */ private[scheduler] abstract class Stage( val id: Int, diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index aa50a49c50232..f58756e6f6179 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -217,6 +217,27 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + // Run a 3-task map stage where one task fails once. + test("failure in tasks in a submitMapStage") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.makeRDD(1 to 3, 3).map { x => + FailureSuiteState.synchronized { + FailureSuiteState.tasksRun += 1 + if (x == 1 && FailureSuiteState.tasksFailed == 0) { + FailureSuiteState.tasksFailed += 1 + throw new Exception("Intentional task failure") + } + } + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + sc.submitMapStage(dep).get() + FailureSuiteState.synchronized { + assert(FailureSuiteState.tasksRun === 4) + } + FailureSuiteState.clear() + } + // TODO: Need to add tests with shuffle fetch failures. 
} diff --git a/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala new file mode 100644 index 0000000000000..3fe28027c3c21 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/AdaptiveSchedulingSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.rdd.{ShuffledRDDPartition, RDD, ShuffledRDD} +import org.apache.spark._ + +object AdaptiveSchedulingSuiteState { + var tasksRun = 0 + + def clear(): Unit = { + tasksRun = 0 + } +} + +/** A special ShuffledRDD where we can pass a ShuffleDependency object to use */ +class CustomShuffledRDD[K, V, C](@transient dep: ShuffleDependency[K, V, C]) + extends RDD[(K, C)](dep.rdd.context, Seq(dep)) { + + override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = { + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[(K, C)]] + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](dep.partitioner.numPartitions)(i => new ShuffledRDDPartition(i)) + } +} + +class AdaptiveSchedulingSuite extends SparkFunSuite with LocalSparkContext { + test("simple use of submitMapStage") { + try { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.parallelize(1 to 3, 3).map { x => + AdaptiveSchedulingSuiteState.tasksRun += 1 + (x, x) + } + val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2)) + val shuffled = new CustomShuffledRDD[Int, Int, Int](dep) + sc.submitMapStage(dep).get() + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + assert(shuffled.collect().toSet == Set((1, 1), (2, 2), (3, 3))) + assert(AdaptiveSchedulingSuiteState.tasksRun == 3) + } finally { + AdaptiveSchedulingSuiteState.clear() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1b9ff740ff530..1c55f90ad9b44 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -152,6 +152,14 @@ class DAGSchedulerSuite override def jobFailed(exception: Exception) = { failure = exception } } + /** A simple helper class for creating custom JobListeners */ + class SimpleListener extends JobListener { + val results = new HashMap[Int, Any] + var failure: Exception = null + override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def jobFailed(exception: Exception): Unit = { 
failure = exception } + } + before { sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() @@ -229,7 +237,7 @@ class DAGSchedulerSuite } } - /** Sends the rdd to the scheduler for scheduling and returns the job id. */ + /** Submits a job to the scheduler and returns the job id. */ private def submit( rdd: RDD[_], partitions: Array[Int], @@ -240,6 +248,15 @@ class DAGSchedulerSuite jobId } + /** Submits a map stage to the scheduler and returns the job id. */ + private def submitMapStage( + shuffleDep: ShuffleDependency[_, _, _], + listener: JobListener = jobListener): Int = { + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener)) + jobId + } + /** Sends TaskSetFailed to the scheduler. */ private def failed(taskSet: TaskSet, message: String) { runEvent(TaskSetFailed(taskSet, message, None)) @@ -1313,6 +1330,230 @@ class DAGSchedulerSuite assert(stackTraceString.contains("org.scalatest.FunSuite")) } + test("simple map stage submission") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(results.size === 0) // No results yet + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage; it should directly do the reduce + submit(reduceRdd, Array(0)) + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with reduce stage also depending on the data") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + + // Submit the map stage by itself + submitMapStage(shuffleDep) + + // Submit a reduce job that depends on this map stage + submit(reduceRdd, Array(0)) + + // Complete tasks for the map stage + completeShuffleMapStageSuccessfully(0, 0, 1) + assert(results.size === 1) + results.clear() + + // Complete tasks for the reduce stage + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + + // Check that if we submit the map stage again, no tasks run + submitMapStage(shuffleDep) + assert(results.size === 1) + assertDataStructuresEmpty() + } + + test("map stage submission with fetch failure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + + // Submit a reduce job that depends on this map stage, but where one reduce will fail a fetch + submit(reduceRdd, Array(0, 1)) + complete(taskSets(1), Seq( + (Success, 42), + (FetchFailed(makeBlockManagerId("hostA"), 
shuffleId, 0, 0, "ignored"), null))) + // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch + // from, then TaskSet 3 will run the reduce stage + scheduler.resubmitFailedStages() + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(3), Seq((Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + + // Run another reduce job without a failure; this should just work + submit(reduceRdd, Array(0, 1)) + complete(taskSets(4), Seq( + (Success, 44), + (Success, 45))) + assert(results === Map(0 -> 44, 1 -> 45)) + results.clear() + assertDataStructuresEmpty() + + // Resubmit the map stage; this should also just work + submitMapStage(shuffleDep) + assert(results.size === 1) + results.clear() + assertDataStructuresEmpty() + } + + /** + * In this test, we have three RDDs with shuffle dependencies, and we submit map stage jobs + * that are waiting on each one, as well as a reduce job on the last one. We test that all of + * these jobs complete even if there are some fetch failures in both shuffles. + */ + test("map stage submission with multiple shared stages and failures") { + val rdd1 = new MyRDD(sc, 2, Nil) + val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(2)) + val rdd2 = new MyRDD(sc, 2, List(dep1)) + val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(2)) + val rdd3 = new MyRDD(sc, 2, List(dep2)) + + val listener1 = new SimpleListener + val listener2 = new SimpleListener + val listener3 = new SimpleListener + + submitMapStage(dep1, listener1) + submitMapStage(dep2, listener2) + submit(rdd3, Array(0, 1), listener = listener3) + + // Complete the first stage + assert(taskSets(0).stageId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", rdd1.partitions.size)), + (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(listener1.results.size === 1) + + // When attempting the second stage, show a fetch failure + assert(taskSets(1).stageId === 1) + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + assert(listener2.results.size === 0) // Second stage listener should not have a result yet + + // Stage 0 should now be running as task set 2; make its task succeed + assert(taskSets(2).stageId === 0) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result + + // Stage 1 should now be running as task set 3; make its first task succeed + assert(taskSets(3).stageId === 1) + complete(taskSets(3), Seq( + (Success, makeMapStatus("hostB", rdd2.partitions.size)), + (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(listener2.results.size === 1) + + // Finally, the reduce job should be running as task set 4; make it see a fetch failure, + // then make it run 
again and succeed + assert(taskSets(4).stageId === 2) + complete(taskSets(4), Seq( + (Success, 52), + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + scheduler.resubmitFailedStages() + + // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 + assert(taskSets(5).stageId === 1) + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + complete(taskSets(6), Seq( + (Success, 53))) + assert(listener3.results === Map(0 -> 52, 1 -> 53)) + assertDataStructuresEmpty() + } + + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from that executor. We want to make sure the stage is not reported + * as done until all tasks have completed. + */ + test("map stage submission with executor failure late map task completions") { + val shuffleMapRdd = new MyRDD(sc, 3, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + + submitMapStage(shuffleDep) + + val oldTaskSet = taskSets(0) + runEvent(CompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Pretend host A was lost + val oldEpoch = mapOutputTracker.getEpoch + runEvent(ExecutorLost("exec-hostA")) + val newEpoch = mapOutputTracker.getEpoch + assert(newEpoch > oldEpoch) + + // Suppose we also get a completed event from task 1 on the same host; this should be ignored + runEvent(CompletionEvent(oldTaskSet.tasks(1), Success, makeMapStatus("hostA", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // A completion from another task should work because it's a non-failed host + runEvent(CompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + + // Now complete tasks in the second task set + val newTaskSet = taskSets(1) + assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + runEvent(CompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 0) // Map stage job should not be complete yet + runEvent(CompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2), + null, createFakeTaskInfo(), null)) + assert(results.size === 1) // Map stage job should now finally be complete + assertDataStructuresEmpty() + + // Also test that a reduce stage using this shuffled data can immediately run + val reduceRDD = new MyRDD(sc, 2, List(shuffleDep)) + results.clear() + submit(reduceRDD, Array(0, 1)) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) + assert(results === Map(0 -> 42, 1 -> 43)) + results.clear() + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. From 55204181004c105c7a3e8c31a099b37e48bfd953 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 14 Sep 2015 19:46:34 -0700 Subject: [PATCH 285/802] [SPARK-10542] [PYSPARK] fix serialize namedtuple Author: Davies Liu Closes #8707 from davies/fix_namedtuple. 
--- python/pyspark/cloudpickle.py | 15 ++++++++++++++- python/pyspark/serializers.py | 1 + python/pyspark/tests.py | 5 +++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 3b647985801b7..95b3abc74244b 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,6 +350,11 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ + # workaround for namedtuple (hijacked by PySpark) + if getattr(obj, '_is_namedtuple_', False): + self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) + return + self.save(_load_class) self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) d.pop('__doc__', None) @@ -382,7 +387,7 @@ def save_instancemethod(self, obj): self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -744,6 +749,14 @@ def _load_class(cls, d): return cls +def _load_namedtuple(name, fields): + """ + Loads a class generated by namedtuple + """ + from collections import namedtuple + return namedtuple(name, fields) + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 411b4dbf481f1..2a1326947f4f5 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,7 @@ def _hack_namedtuple(cls): def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True return cls diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 8bfed074c9052..647504c32f156 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -218,6 +218,11 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) + from pyspark.cloudpickle import dumps + P2 = loads(dumps(P)) + p3 = P2(1, 3) + self.assertEqual(p1, p3) + def test_itemgetter(self): from operator import itemgetter ser = CloudPickleSerializer() From 4ae4d54794778042b2cc983e52757edac02412ab Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Sep 2015 21:37:43 -0700 Subject: [PATCH 286/802] [SPARK-9793] [MLLIB] [PYSPARK] PySpark DenseVector, SparseVector implement __eq__ and __hash__ correctly PySpark DenseVector, SparseVector ```__eq__``` method should use semantics equality, and DenseVector can compared with SparseVector. Implement PySpark DenseVector, SparseVector ```__hash__``` method based on the first 16 entries. That will make PySpark Vector objects can be used in collections. Author: Yanbo Liang Closes #8166 from yanboliang/spark-9793. 
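To make the new comparison and hashing semantics concrete, a short sketch mirroring the test_eq and test_hash cases added in this patch; it assumes a PySpark build that already includes this change.

    from pyspark.mllib.linalg import DenseVector, SparseVector

    dense = DenseVector([0.0, 1.0, 0.0, 5.5])
    sparse = SparseVector(4, [(1, 1.0), (3, 5.5)])

    # Equality is now semantic and works across the two types:
    # same size and same entries means equal.
    assert dense == sparse

    # Equal vectors hash alike, so they can be used together in dicts and sets.
    assert hash(dense) == hash(sparse)
    assert len({dense, sparse}) == 1

    # A different size (or different values) means not equal.
    assert dense != SparseVector(6, [(1, 1.0), (3, 5.5)])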
--- python/pyspark/mllib/linalg/__init__.py | 90 ++++++++++++++++++++----- python/pyspark/mllib/tests.py | 32 +++++++++ 2 files changed, 107 insertions(+), 15 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 334dc8e38bb8f..380f86e9b44f8 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -25,6 +25,7 @@ import sys import array +import struct if sys.version >= '3': basestring = str @@ -122,6 +123,13 @@ def _format_float_list(l): return [_format_float(x) for x in l] +def _double_to_long_bits(value): + if np.isnan(value): + value = float('nan') + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + class VectorUDT(UserDefinedType): """ SQL user-defined type (UDT) for Vector. @@ -404,11 +412,31 @@ def __repr__(self): return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) + if isinstance(other, DenseVector): + return np.array_equal(self.array, other.array) + elif isinstance(other, SparseVector): + if len(self) != other.size: + return False + return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return False def __ne__(self, other): return not self == other + def __hash__(self): + size = len(self) + result = 31 + size + nnz = 0 + i = 0 + while i < size and nnz < 128: + if self.array[i] != 0: + result = 31 * result + i + bits = _double_to_long_bits(self.array[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + def __getattr__(self, item): return getattr(self.array, item) @@ -704,20 +732,14 @@ def __repr__(self): return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): - """ - Test SparseVectors for equality. - - >>> v1 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) - >>> v1 == v2 - True - >>> v1 != v2 - False - """ - return (isinstance(other, self.__class__) - and other.size == self.size - and np.array_equal(other.indices, self.indices) - and np.array_equal(other.values, self.values)) + if isinstance(other, SparseVector): + return other.size == self.size and np.array_equal(other.indices, self.indices) \ + and np.array_equal(other.values, self.values) + elif isinstance(other, DenseVector): + if self.size != len(other): + return False + return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return False def __getitem__(self, index): inds = self.indices @@ -739,6 +761,19 @@ def __getitem__(self, index): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + result = 31 + self.size + nnz = 0 + i = 0 + while i < len(self.values) and nnz < 128: + if self.values[i] != 0: + result = 31 * result + int(self.indices[i]) + bits = _double_to_long_bits(self.values[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + class Vectors(object): @@ -841,6 +876,31 @@ def parse(s): def zeros(size): return DenseVector(np.zeros(size)) + @staticmethod + def _equals(v1_indices, v1_values, v2_indices, v2_values): + """ + Check equality between sparse/dense vectors, + v1_indices and v2_indices assume to be strictly increasing. 
+ """ + v1_size = len(v1_values) + v2_size = len(v2_values) + k1 = 0 + k2 = 0 + all_equal = True + while all_equal: + while k1 < v1_size and v1_values[k1] == 0: + k1 += 1 + while k2 < v2_size and v2_values[k2] == 0: + k2 += 1 + + if k1 >= v1_size or k2 >= v2_size: + return k1 >= v1_size and k2 >= v2_size + + all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] + k1 += 1 + k2 += 1 + return all_equal + class Matrix(object): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5097c5e8ba4cd..636f9a06cab7b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -194,6 +194,38 @@ def test_squared_distance(self): self.assertEquals(3.0, _squared_distance(sv, arr)) self.assertEquals(3.0, _squared_distance(sv, narr)) + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(hash(v1), hash(v2)) + self.assertEquals(hash(v1), hash(v3)) + self.assertEquals(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEquals(v1, v2) + self.assertEquals(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + def test_conversion(self): # numpy arrays should be automatically upcast to float64 # tests for fix of [SPARK-5089] From 610971ecfe858b1a48ce69b25614afe52bcbe77f Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 14 Sep 2015 21:58:52 -0700 Subject: [PATCH 287/802] [SPARK-10273] Add @since annotation to pyspark.mllib.feature Duplicated the since decorator from pyspark.sql into pyspark (also tweaked to handle functions without docstrings). Added since to methods + "versionadded::" to classes (derived from the git file history in pyspark). Author: noelsmith Closes #8633 from noel-smith/SPARK-10273-since-mllib-feature. --- python/pyspark/mllib/feature.py | 58 ++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f921e3ad1a314..7b077b058c3fd 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -30,7 +30,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( @@ -84,11 +84,14 @@ class Normalizer(VectorTransformer): >>> nor2 = Normalizer(float("inf")) >>> nor2.transform(v) DenseVector([0.0, 0.5, 1.0]) + + .. 
versionadded:: 1.2.0 """ def __init__(self, p=2.0): assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) + @since('1.2.0') def transform(self, vector): """ Applies unit length normalization on a vector. @@ -133,7 +136,11 @@ class StandardScalerModel(JavaVectorTransformer): .. note:: Experimental Represents a StandardScaler model that can transform vectors. + + .. versionadded:: 1.2.0 """ + + @since('1.2.0') def transform(self, vector): """ Applies standardization transformation on a vector. @@ -149,6 +156,7 @@ def transform(self, vector): """ return JavaVectorTransformer.transform(self, vector) + @since('1.4.0') def setWithMean(self, withMean): """ Setter of the boolean which decides @@ -157,6 +165,7 @@ def setWithMean(self, withMean): self.call("setWithMean", withMean) return self + @since('1.4.0') def setWithStd(self, withStd): """ Setter of the boolean which decides @@ -189,6 +198,8 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + + .. versionadded:: 1.2.0 """ def __init__(self, withMean=False, withStd=True): if not (withMean or withStd): @@ -196,6 +207,7 @@ def __init__(self, withMean=False, withStd=True): self.withMean = withMean self.withStd = withStd + @since('1.2.0') def fit(self, dataset): """ Computes the mean and variance and stores as a model to be used @@ -215,7 +227,11 @@ class ChiSqSelectorModel(JavaVectorTransformer): .. note:: Experimental Represents a Chi Squared selector model. + + .. versionadded:: 1.4.0 """ + + @since('1.4.0') def transform(self, vector): """ Applies transformation on a vector. @@ -245,10 +261,13 @@ class ChiSqSelector(object): SparseVector(1, {0: 6.0}) >>> model.transform(DenseVector([8.0, 9.0, 5.0])) DenseVector([5.0]) + + .. versionadded:: 1.4.0 """ def __init__(self, numTopFeatures): self.numTopFeatures = int(numTopFeatures) + @since('1.4.0') def fit(self, data): """ Returns a ChiSquared feature selector. @@ -265,6 +284,8 @@ def fit(self, data): class PCAModel(JavaVectorTransformer): """ Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. + + .. versionadded:: 1.5.0 """ @@ -281,6 +302,8 @@ class PCA(object): 1.648... >>> pcArray[1] -4.013... + + .. versionadded:: 1.5.0 """ def __init__(self, k): """ @@ -288,6 +311,7 @@ def __init__(self, k): """ self.k = int(k) + @since('1.5.0') def fit(self, data): """ Computes a [[PCAModel]] that contains the principal components of the input vectors. @@ -312,14 +336,18 @@ class HashingTF(object): >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) + + .. versionadded:: 1.2.0 """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + @since('1.2.0') def indexOf(self, term): """ Returns the index of the input term. """ return hash(term) % self.numFeatures + @since('1.2.0') def transform(self, document): """ Transforms the input document (list of terms) to term frequency @@ -339,7 +367,10 @@ def transform(self, document): class IDFModel(JavaVectorTransformer): """ Represents an IDF model that can transform term frequency vectors. + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, x): """ Transforms term frequency (TF) vectors to TF-IDF vectors. @@ -358,6 +389,7 @@ def transform(self, x): """ return JavaVectorTransformer.transform(self, x) + @since('1.4.0') def idf(self): """ Returns the current IDF vector. 
@@ -401,10 +433,13 @@ class IDF(object): DenseVector([0.0, 0.0, 1.3863, 0.863]) >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0))) SparseVector(4, {1: 0.0, 3: 0.5754}) + + .. versionadded:: 1.2.0 """ def __init__(self, minDocFreq=0): self.minDocFreq = minDocFreq + @since('1.2.0') def fit(self, dataset): """ Computes the inverse document frequency. @@ -420,7 +455,10 @@ def fit(self, dataset): class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model + + .. versionadded:: 1.2.0 """ + @since('1.2.0') def transform(self, word): """ Transforms a word to its vector representation @@ -435,6 +473,7 @@ def transform(self, word): except Py4JJavaError: raise ValueError("%s not found" % word) + @since('1.2.0') def findSynonyms(self, word, num): """ Find synonyms of a word @@ -450,6 +489,7 @@ def findSynonyms(self, word, num): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + @since('1.4.0') def getVectors(self): """ Returns a map of words to their vector representations. @@ -457,7 +497,11 @@ def getVectors(self): return self.call("getVectors") @classmethod + @since('1.5.0') def load(cls, sc, path): + """ + Load a model from the given path. + """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) return Word2VecModel(jmodel) @@ -507,6 +551,8 @@ class Word2Vec(object): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 1.2.0 """ def __init__(self): """ @@ -519,6 +565,7 @@ def __init__(self): self.seed = random.randint(0, sys.maxsize) self.minCount = 5 + @since('1.2.0') def setVectorSize(self, vectorSize): """ Sets vector size (default: 100). @@ -526,6 +573,7 @@ def setVectorSize(self, vectorSize): self.vectorSize = vectorSize return self + @since('1.2.0') def setLearningRate(self, learningRate): """ Sets initial learning rate (default: 0.025). @@ -533,6 +581,7 @@ def setLearningRate(self, learningRate): self.learningRate = learningRate return self + @since('1.2.0') def setNumPartitions(self, numPartitions): """ Sets number of partitions (default: 1). Use a small number for @@ -541,6 +590,7 @@ def setNumPartitions(self, numPartitions): self.numPartitions = numPartitions return self + @since('1.2.0') def setNumIterations(self, numIterations): """ Sets number of iterations (default: 1), which should be smaller @@ -549,6 +599,7 @@ def setNumIterations(self, numIterations): self.numIterations = numIterations return self + @since('1.2.0') def setSeed(self, seed): """ Sets random seed. @@ -556,6 +607,7 @@ def setSeed(self, seed): self.seed = seed return self + @since('1.4.0') def setMinCount(self, minCount): """ Sets minCount, the minimum number of times a token must appear @@ -564,6 +616,7 @@ def setMinCount(self, minCount): self.minCount = minCount return self + @since('1.2.0') def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -596,10 +649,13 @@ class ElementwiseProduct(VectorTransformer): >>> rdd = sc.parallelize([a, b]) >>> eprod.transform(rdd).collect() [DenseVector([2.0, 2.0, 9.0]), DenseVector([9.0, 6.0, 12.0])] + + .. versionadded:: 1.5.0 """ def __init__(self, scalingVector): self.scalingVector = _convert_to_vector(scalingVector) + @since('1.5.0') def transform(self, vector): """ Computes the Hadamard product of the vector. 
From a2249359d5b0368318a714b292bb1d0dc70c0e27 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 14 Sep 2015 21:59:40 -0700 Subject: [PATCH 288/802] [SPARK-10275] [MLLIB] Add @since annotation to pyspark.mllib.random Author: Yu ISHIKAWA Closes #8666 from yu-iskw/SPARK-10275. --- python/pyspark/mllib/random.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 06fbc0eb6aef0..9c733b1332bc0 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -21,6 +21,7 @@ from functools import wraps +from pyspark import since from pyspark.mllib.common import callMLlibFunc @@ -39,9 +40,12 @@ class RandomRDDs(object): """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. + + .. addedversion:: 1.1.0 """ @staticmethod + @since("1.1.0") def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the @@ -72,6 +76,7 @@ def uniformRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.1.0") def normalRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the standard normal @@ -100,6 +105,7 @@ def normalRDD(sc, size, numPartitions=None, seed=None): return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @staticmethod + @since("1.3.0") def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the log normal @@ -132,6 +138,7 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): size, numPartitions, seed) @staticmethod + @since("1.1.0") def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Poisson @@ -158,6 +165,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Exponential @@ -184,6 +192,7 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @staticmethod + @since("1.3.0") def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the Gamma @@ -216,6 +225,7 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -241,6 +251,7 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.1.0") def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -266,6 +277,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. 
samples drawn @@ -300,6 +312,7 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed @staticmethod @toArray + @since("1.1.0") def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -330,6 +343,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): @staticmethod @toArray + @since("1.3.0") def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn @@ -360,6 +374,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No @staticmethod @toArray + @since("1.3.0") def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None): """ Generates an RDD comprised of vectors containing i.i.d. samples drawn From 833be73314b85b390a9007ed6ed63dc47bbd9e4f Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Mon, 14 Sep 2015 23:40:29 -0700 Subject: [PATCH 289/802] Small fixes to docs Links work now properly + consistent use of *Spark standalone cluster* (Spark uppercase + lowercase the rest -- seems agreed in the other places in the docs). Author: Jacek Laskowski Closes #8759 from jaceklaskowski/docs-submitting-apps. --- docs/submitting-applications.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index e58645274e525..7ea4d6f1a3f8f 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -65,8 +65,8 @@ For Python applications, simply pass a `.py` file in the place of ` Date: Mon, 14 Sep 2015 23:41:06 -0700 Subject: [PATCH 290/802] [SPARK-10598] [DOCS] Comments preceding toMessage method state: "The edge partition is encoded in the lower * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int.". References to bytes should be changed to bits. This contribution is my original work and I license the work to the Spark project under it's open source license. Author: Robin East Closes #8756 from insidedctm/master. --- .../org/apache/spark/graphx/impl/RoutingTablePartition.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index eb3c997e0f3c0..4f1260a5a67b2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -34,7 +34,7 @@ object RoutingTablePartition { /** * A message from an edge partition to a vertex specifying the position in which the edge * partition references the vertex (src, dst, or both). The edge partition is encoded in the lower - * 30 bytes of the Int, and the position is encoded in the upper 2 bytes of the Int. + * 30 bits of the Int, and the position is encoded in the upper 2 bits of the Int. */ type RoutingTableMessage = (VertexId, Int) From 09b7e7c19897549a8622aec095f27b8b38a1a4d3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 00:54:20 -0700 Subject: [PATCH 291/802] Update version to 1.6.0-SNAPSHOT. Author: Reynold Xin Closes #8350 from rxin/1.6. 
--- R/pkg/DESCRIPTION | 2 +- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- core/src/main/scala/org/apache/spark/package.scala | 2 +- docs/_config.yml | 4 ++-- examples/pom.xml | 2 +- external/flume-assembly/pom.xml | 2 +- external/flume-sink/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka-assembly/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt-assembly/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/kinesis-asl-assembly/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- network/yarn/pom.xml | 2 +- pom.xml | 2 +- project/MimaBuild.scala | 2 +- project/MimaExcludes.scala | 13 +++++++++++-- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive-thriftserver/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 38 files changed, 49 insertions(+), 40 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d0d7201f004a2..a3a16c42a6214 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.5.0 +Version: 1.6.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman diff --git a/assembly/pom.xml b/assembly/pom.xml index e9c6d26ccddc7..4b60ee00ffbe5 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index ed5c37e595a96..3baf8d47b4dc7 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index a46292c13bcc0..e31d90f608892 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 8ae76c5f72f2e..7515aad09db73 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.5.0-SNAPSHOT" + val SPARK_VERSION = "1.6.0-SNAPSHOT" } diff --git a/docs/_config.yml b/docs/_config.yml index c0e031a83ba9c..c59cc465ef89d 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. 
-SPARK_VERSION: 1.5.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.5.0 +SPARK_VERSION: 1.6.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.6.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.21.0 diff --git a/examples/pom.xml b/examples/pom.xml index e6884b09dca94..f5ab2a7fdc098 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index 561ed4babe5d0..dceedcf23ed5b 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 0664cfb2021e1..d7c2ac474a18d 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 14f7daaf417e0..132062f94fb45 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 6f4e2a89e9af7..a9ed39ef8c9a0 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index ded863bd985e8..05abd9e2e6810 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index 8412600633734..89713a28ca6a8 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 69b309876a0db..05e6338a08b0a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 178ae8de13b57..244ad58ae9593 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 37bfd10d43663..171df8682c848 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 3636a9037d43f..81794a8536318 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 51af3e6f2225f..61ba4787fbf90 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml 
b/extras/kinesis-asl/pom.xml index 521b53e230c4a..6dd8ff69c2943 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 478d0019a25f0..87a4f05a05961 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 853dea9a7795e..202fc19002d12 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 2fd768d8119c4..ed38e66aa2467 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index a5db14407b4fc..22c0c6008ba37 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/network/common/pom.xml b/network/common/pom.xml index 4141fcb8267a5..1cc054a8936c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 3d2edf9d94515..7a66c968041ce 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index a99f7c4392d3d..e745180eace78 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/pom.xml b/pom.xml index 421357e141572..6535994641145 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index f16bf989f200b..519052620246f 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.4.0" + val previousSparkVersion = "1.5.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b8b6c8ffa375..87b141cd3b058 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -35,8 +35,17 @@ object MimaExcludes { def excludes(version: String) = version match { case v if v.startsWith("1.6") => Seq( - MimaBuild.excludeSparkPackage("network") - ) + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("network"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in 1.3. + excludePackage("org.spark-project.jetty"), + MimaBuild.excludeSparkPackage("unused"), + // SQL execution is considered private. 
+ excludePackage("org.apache.spark.sql.execution") + ) ++ + MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), diff --git a/repl/pom.xml b/repl/pom.xml index a5a0f1fc2c857..5cf416a4a5448 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 75ab575dfde83..6cfd53e868f83 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 349007789f634..465aa3a3888c2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 3566c87dd248c..f7fe085f34d84 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index be1607476e254..ac67fe5f47be9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 697895e72fe5b..5cc9001b0e9ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 298ee2348b58e..1e64f280e5bed 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 89475ee3cf5a1..066abe92e51c0 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index f6737695307a2..d8e4a4bbead81 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.5.0-SNAPSHOT + 1.6.0-SNAPSHOT ../pom.xml From c35fdcb7e9c01271ce560dba4e0bd37569c8f5d1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 15 Sep 2015 09:58:49 -0700 Subject: [PATCH 292/802] [SPARK-10491] [MLLIB] move RowMatrix.dspr to BLAS jira: https://issues.apache.org/jira/browse/SPARK-10491 We implemented dspr with sparse vector support in `RowMatrix`. This method is also used in WeightedLeastSquares and other places. It would be useful to move it to `linalg.BLAS`. Let me know if new UT needed. Author: Yuhao Yang Closes #8663 from hhbyyh/movedspr. 
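For reference, `spr` accumulates `alpha * x * x^T` into the upper triangle of a symmetric matrix stored as a packed, column-major array. The following is an illustrative Python reference of that packed update for the dense case only; it is not Spark code, the actual implementation is the Scala `BLAS.spr` in the diff below:

```
# Illustrative reference for the dense case of A := alpha * x * x^T + A,
# with only the upper triangle of A stored column-major in a flat array U.
# Not Spark code; the real implementation is BLAS.spr in the diff below.
def spr_dense(alpha, x, U):
    n = len(x)
    for j in range(n):                    # column j
        col_start = j * (j + 1) // 2      # offset of A[0, j] in the packed array
        for i in range(j + 1):            # rows 0..j of the upper triangle
            U[col_start + i] += alpha * x[i] * x[j]

x = [1.0, 2.0, 3.0]
U = [0.0] * 6                             # packed upper triangle of a 3x3 matrix
spr_dense(1.0, x, U)
print(U)  # [1.0, 2.0, 4.0, 3.0, 6.0, 9.0] == [A00, A01, A11, A02, A12, A22]
```
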
--- .../spark/ml/optim/WeightedLeastSquares.scala | 4 +- .../org/apache/spark/mllib/linalg/BLAS.scala | 44 +++++++++++++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 40 +---------------- .../apache/spark/mllib/linalg/BLASSuite.scala | 25 +++++++++++ 4 files changed, 72 insertions(+), 41 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index a99e2ac4c6913..0ff8931b0bab4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares( if (fitIntercept) { // shift centers // A^T A - aBar aBar^T - RowMatrix.dspr(-1.0, aBar, aaValues) + BLAS.spr(-1.0, aBar, aaValues) // A^T b - bBar aBar BLAS.axpy(-bBar, aBar, abBar) } @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares { bbSum += w * b * b BLAS.axpy(w, a, aSum) BLAS.axpy(w * b, a, abSum) - RowMatrix.dspr(w, a, aaSum.values) + BLAS.spr(w, a, aaSum) this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 9ee81eda8a8c0..df9f4ae145b88 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging { _nativeBLAS } + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix in a [[DenseVector]](column major) + */ + def spr(alpha: Double, v: Vector, U: DenseVector): Unit = { + spr(alpha, v, U.values) + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + val n = v.size + v match { + case DenseVector(values) => + NativeBLAS.dspr("U", n, alpha, values, 1, U) + case SparseVector(size, indices, values) => + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 83779ac88989b..e55ef26858adb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} -import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ @@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") ( // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( seqOp = (U, v) => { - RowMatrix.dspr(1.0, v, U.data) + BLAS.spr(1.0, v, U.data) U }, combOp = (U1, U2) => U1 += U2) @@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") ( @Experimental object RowMatrix { - /** - * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. - * - * @param U the upper triangular part of the matrix packed in an array (column major) - */ - // TODO: SPARK-10491 - move this method to linalg.BLAS - private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { - // TODO: Find a better home (breeze?) for this method. - val n = v.size - v match { - case DenseVector(values) => - blas.dspr("U", n, alpha, values, 1, U) - case SparseVector(size, indices, values) => - val nnz = indices.length - var colStartIdx = 0 - var prevCol = 0 - var col = 0 - var j = 0 - var i = 0 - var av = 0.0 - while (j < nnz) { - col = indices(j) - // Skip empty columns. - colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 - col = indices(j) - av = alpha * values(j) - i = 0 - while (i <= j) { - U(colStartIdx + indices(i)) += av * values(i) - i += 1 - } - j += 1 - prevCol = col - } - } - } - /** * Fills a full square matrix from its upper triangular part. 
*/ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 8db5c8424abe9..96e5ffef7a131 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite { } } + test("spr") { + // test dense vector + val alpha = 0.1 + val x = new DenseVector(Array(1.0, 2, 2.1, 4)) + val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6)) + + spr(alpha, x, U) + assert(U ~== expected absTol 1e-9) + + val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5)) + withClue("Size of vector must match the rank of matrix") { + intercept[Exception] { + spr(alpha, x, matrix33) + } + } + + // test sparse vector + val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2)) + val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4)) + spr(0.1, sv, U2) + val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4)) + assert(U2 ~== expectedSparse absTol 1e-15) + } + test("syr") { val dA = new DenseMatrix(4, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8)) From 8abef21dac1a6538c4e4e0140323b83d804d602b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 10:45:02 -0700 Subject: [PATCH 293/802] [SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py. This change does two things: - tag a few tests and adds the mechanism in the build to be able to disable those tags, both in maven and sbt, for both junit and scalatest suites. - add some logic to run-tests.py to disable some tags depending on what files have changed; that's used to disable expensive tests when a module hasn't explicitly been changed, to speed up testing for changes that don't directly affect those modules. Author: Marcelo Vanzin Closes #8437 from vanzin/test-tags. 
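The moving parts of this change are spread over several files; the following is a condensed, self-contained sketch of the flow, with names taken from the diff below (the real logic lives in `dev/run-tests.py`, `pom.xml` and `project/SparkBuild.scala`):

```
# Condensed sketch of the tag-exclusion flow; names mirror the diff below.
class Module(object):
    def __init__(self, name, test_tags=()):
        self.name = name
        self.test_tags = list(test_tags)

core = Module("core")
hive = Module("hive", ["org.apache.spark.sql.hive.ExtendedHiveTest"])
yarn = Module("yarn", ["org.apache.spark.deploy.yarn.ExtendedYarnTest"])
all_modules = [core, hive, yarn]

def determine_tags_to_exclude(changed_modules):
    # Tests carrying a module's tags run only when that module itself changed.
    tags = []
    for m in all_modules:
        if m not in changed_modules:
            tags += m.test_tags
    return tags

# Only "core" changed here, so the expensive Hive and YARN suites are excluded;
# the build turns this property into ScalaTest "-l <tag>" and JUnit
# "--exclude-categories=<tag>" arguments.
excluded_tags = determine_tags_to_exclude([core])
print("-Dtest.exclude.tags=" + ",".join(excluded_tags))
```
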
--- core/pom.xml | 10 ------- dev/run-tests.py | 19 ++++++++++++-- dev/sparktestsupport/modules.py | 24 ++++++++++++++++- external/flume/pom.xml | 10 ------- external/kafka/pom.xml | 10 ------- external/mqtt/pom.xml | 10 ------- external/twitter/pom.xml | 10 ------- external/zeromq/pom.xml | 10 ------- extras/java8-tests/pom.xml | 10 ------- extras/kinesis-asl/pom.xml | 5 ---- launcher/pom.xml | 5 ---- mllib/pom.xml | 10 ------- network/common/pom.xml | 10 ------- network/shuffle/pom.xml | 10 ------- pom.xml | 17 ++++++++++-- project/SparkBuild.scala | 13 ++++++++-- sql/core/pom.xml | 5 ---- .../execution/HiveCompatibilitySuite.scala | 2 ++ sql/hive/pom.xml | 5 ---- .../spark/sql/hive/ExtendedHiveTest.java | 26 +++++++++++++++++++ .../spark/sql/hive/client/VersionsSuite.scala | 2 ++ streaming/pom.xml | 10 ------- unsafe/pom.xml | 10 ------- .../spark/deploy/yarn/ExtendedYarnTest.java | 26 +++++++++++++++++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 1 + .../yarn/YarnShuffleIntegrationSuite.scala | 1 + 26 files changed, 124 insertions(+), 147 deletions(-) create mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java create mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index e31d90f608892..8a20181096223 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,16 +331,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index d8b22e1665e7b..1a816585187d9 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,6 +118,14 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) +def determine_tags_to_exclude(changed_modules): + tags = [] + for m in modules.all_modules: + if m not in changed_modules: + tags += m.test_tags + return tags + + # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -369,6 +377,7 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] + profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -392,7 +401,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules): +def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -401,6 +410,10 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) + + if excluded_tags: + test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] + if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -500,8 +513,10 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) + excluded_tags = 
determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] + excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -541,7 +556,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules) + run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 346452f3174e4..65397f1f3e0bc 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - should_run_r_tests=False): + test_tags=(), should_run_r_tests=False): """ Define a new module. @@ -50,6 +50,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. + :param test_tags A set of tags that will be excluded when running unit tests if the module + is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -60,6 +62,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations + self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -85,6 +88,9 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", + ], + test_tags=[ + "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -398,6 +404,22 @@ def contains_file(self, filename): ) +yarn = Module( + name="yarn", + dependencies=[], + source_file_regexes=[ + "yarn/", + "network/yarn/", + ], + sbt_test_goals=[ + "yarn/test", + "network-yarn/test", + ], + test_tags=[ + "org.apache.spark.deploy.yarn.ExtendedYarnTest" + ] +) + # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. 
root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 132062f94fb45..3154e36c21ef5 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,16 +66,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 05abd9e2e6810..7d0d46dadc727 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,16 +86,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 05e6338a08b0a..913c47d33f488 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 244ad58ae9593..9137bf25ee8ae 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,16 +58,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 171df8682c848..6fec4f0e8a0f9 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,16 +57,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 81794a8536318..dba3dda8a9562 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,16 +58,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 6dd8ff69c2943..760f183a2ef37 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,11 +74,6 @@ scalacheck_${scala.binary.version} test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index ed38e66aa2467..80696280a1d18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,11 +42,6 @@ log4j test - - junit - junit - test - org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 22c0c6008ba37..5dedacb38874e 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,16 +94,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 1cc054a8936c5..9c12cca0df609 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,16 +64,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 7a66c968041ce..e4f4c57b683c8 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,16 +78,6 @@ test-jar test - - junit - junit - test - - - com.novocode - junit-interface - test - log4j log4j diff --git a/pom.xml b/pom.xml index 6535994641145..2927d3e107563 100644 --- a/pom.xml +++ b/pom.xml @@ -181,6 +181,7 @@ 0.9.2 ${java.home} + @@ -1952,6 +1964,7 @@ __not_used__ + ${test.exclude.tags} diff --git 
a/project/SparkBuild.scala b/project/SparkBuild.scala index 901cfa538d23e..d80d300f1c3b2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,11 +567,20 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", + // Exclude tags defined in a system property + testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, + sys.props.get("test.exclude.tags").map { tags => + tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq + }.getOrElse(Nil): _*), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, + sys.props.get("test.exclude.tags").map { tags => + Seq("--exclude-categories=" + tags) + }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 465aa3a3888c2..fa6732db183d8 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,11 +73,6 @@ jackson-databind ${fasterxml.jackson.version} - - junit - junit - test - org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ab309e0a1d36b..ffc4c32794ca4 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,11 +24,13 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ +@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index ac67fe5f47be9..82cfeb2bb95d3 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,11 +160,6 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java new file mode 100644 index 0000000000000..e2183183fb559 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index f0bb77092c0cf..888d1b7b45532 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -32,6 +33,7 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ +@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 5cc9001b0e9ab..1e6ee009ca6d5 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,21 +84,11 @@ scalacheck_${scala.binary.version} test - - junit - junit - test - org.seleniumhq.selenium selenium-java test - - com.novocode - junit-interface - test - target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 066abe92e51c0..4e8b9a84bb67f 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,16 +55,6 @@ - - junit - junit - test - - - com.novocode - junit-interface - test - org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java new file mode 100644 index 0000000000000..7a8f2fe979c1f --- /dev/null +++ b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b5a42fd6afd98..105c3090d489d 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ +@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 8d9c9b3004eda..4700e2428df08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ +@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 7ca30b505c3561dc2832b463be4c6301a90380e4 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Tue, 15 Sep 2015 12:23:20 -0700 Subject: [PATCH 294/802] [PYSPARK] [MLLIB] [DOCS] Replaced addversion with versionadded in mllib.random Missed this when reviewing `pyspark.mllib.random` for SPARK-10275. Author: noelsmith Closes #8773 from noel-smith/mllib-random-versionadded-fix. --- python/pyspark/mllib/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 9c733b1332bc0..6a3c643b66417 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -41,7 +41,7 @@ class RandomRDDs(object): Generator methods for creating RDDs comprised of i.i.d samples from some distribution. - .. addedversion:: 1.1.0 + .. versionadded:: 1.1.0 """ @staticmethod From 0d9ab016755d5b56ce4043f229602169fd752e88 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 15 Sep 2015 12:25:31 -0700 Subject: [PATCH 295/802] Closes #8738 Closes #8767 Closes #2491 Closes #6795 Closes #2096 Closes #7722 From 416003b26401894ec712e1a5291a92adfbc5af01 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Tue, 15 Sep 2015 20:42:33 +0100 Subject: [PATCH 296/802] [DOCS] Small fixes to Spark on Yarn doc * a follow-up to 16b6d18613e150c7038c613992d80a7828413e66 as `--num-executors` flag is not suppported. * links + formatting Author: Jacek Laskowski Closes #8762 from jaceklaskowski/docs-spark-on-yarn. 
--- docs/running-on-yarn.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 5159ef9e3394e..d1244323edfff 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -18,16 +18,16 @@ Spark application's configuration (driver, executors, and the AM when running in There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.html) modes, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. + To launch a Spark application in `yarn-cluster` mode: - `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ --master yarn-cluster \ - --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ @@ -37,7 +37,7 @@ For example: The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. -To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. The following shows how you can run `spark-shell` in `yarn-client` mode: $ ./bin/spark-shell --master yarn-client @@ -54,8 +54,8 @@ In `yarn-cluster` mode, the driver runs on a different machine than the client, # Preparations -Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. -Binary distributions can be downloaded from the Spark project website. +Running Spark on YARN requires a binary distribution of Spark which is built with YARN support. +Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration From b42059d2efdf3322334694205a6d951bcc291644 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 15 Sep 2015 13:03:38 -0700 Subject: [PATCH 297/802] Revert "[SPARK-10300] [BUILD] [TESTS] Add support for test tags in run-tests.py." This reverts commit 8abef21dac1a6538c4e4e0140323b83d804d602b. 
--- core/pom.xml | 10 +++++++ dev/run-tests.py | 19 ++------------ dev/sparktestsupport/modules.py | 24 +---------------- external/flume/pom.xml | 10 +++++++ external/kafka/pom.xml | 10 +++++++ external/mqtt/pom.xml | 10 +++++++ external/twitter/pom.xml | 10 +++++++ external/zeromq/pom.xml | 10 +++++++ extras/java8-tests/pom.xml | 10 +++++++ extras/kinesis-asl/pom.xml | 5 ++++ launcher/pom.xml | 5 ++++ mllib/pom.xml | 10 +++++++ network/common/pom.xml | 10 +++++++ network/shuffle/pom.xml | 10 +++++++ pom.xml | 17 ++---------- project/SparkBuild.scala | 13 ++-------- sql/core/pom.xml | 5 ++++ .../execution/HiveCompatibilitySuite.scala | 2 -- sql/hive/pom.xml | 5 ++++ .../spark/sql/hive/ExtendedHiveTest.java | 26 ------------------- .../spark/sql/hive/client/VersionsSuite.scala | 2 -- streaming/pom.xml | 10 +++++++ unsafe/pom.xml | 10 +++++++ .../spark/deploy/yarn/ExtendedYarnTest.java | 26 ------------------- .../spark/deploy/yarn/YarnClusterSuite.scala | 1 - .../yarn/YarnShuffleIntegrationSuite.scala | 1 - 26 files changed, 147 insertions(+), 124 deletions(-) delete mode 100644 sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java delete mode 100644 yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java diff --git a/core/pom.xml b/core/pom.xml index 8a20181096223..e31d90f608892 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -331,6 +331,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.curator curator-test diff --git a/dev/run-tests.py b/dev/run-tests.py index 1a816585187d9..d8b22e1665e7b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -118,14 +118,6 @@ def determine_modules_to_test(changed_modules): return modules_to_test.union(set(changed_modules)) -def determine_tags_to_exclude(changed_modules): - tags = [] - for m in modules.all_modules: - if m not in changed_modules: - tags += m.test_tags - return tags - - # ------------------------------------------------------------------------------------------------- # Functions for working with subprocesses and shell tools # ------------------------------------------------------------------------------------------------- @@ -377,7 +369,6 @@ def detect_binary_inop_with_mima(): def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] - profiles_and_goals = test_profiles + mvn_test_goals print("[info] Running Spark tests using Maven with these arguments: ", @@ -401,7 +392,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): exec_sbt(profiles_and_goals) -def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): +def run_scala_tests(build_tool, hadoop_version, test_modules): """Function to properly execute all tests passed in as a set from the `determine_test_suites` function""" set_title_and_block("Running Spark unit tests", "BLOCK_SPARK_UNIT_TESTS") @@ -410,10 +401,6 @@ def run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags): test_profiles = get_hadoop_profiles(hadoop_version) + \ list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) - - if excluded_tags: - test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)] - if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -513,10 +500,8 @@ def main(): target_branch = os.environ["ghprbTargetBranch"] changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) - 
excluded_tags = determine_tags_to_exclude(changed_modules) if not changed_modules: changed_modules = [modules.root] - excluded_tags = [] print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) @@ -556,7 +541,7 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, test_modules, excluded_tags) + run_scala_tests(build_tool, hadoop_version, test_modules) modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 65397f1f3e0bc..346452f3174e4 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + should_run_r_tests=False): """ Define a new module. @@ -50,8 +50,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param blacklisted_python_implementations: A set of Python implementations that are not supported by this module's Python components. The values in this set should match strings returned by Python's `platform.python_implementation()`. - :param test_tags A set of tags that will be excluded when running unit tests if the module - is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. """ self.name = name @@ -62,7 +60,6 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations - self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests self.dependent_modules = set() @@ -88,9 +85,6 @@ def contains_file(self, filename): "catalyst/test", "sql/test", "hive/test", - ], - test_tags=[ - "org.apache.spark.sql.hive.ExtendedHiveTest" ] ) @@ -404,22 +398,6 @@ def contains_file(self, filename): ) -yarn = Module( - name="yarn", - dependencies=[], - source_file_regexes=[ - "yarn/", - "network/yarn/", - ], - sbt_test_goals=[ - "yarn/test", - "network-yarn/test", - ], - test_tags=[ - "org.apache.spark.deploy.yarn.ExtendedYarnTest" - ] -) - # The root module is a dummy module which is used to run all of the tests. # No other modules should directly depend on this module. 
root = Module( diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 3154e36c21ef5..132062f94fb45 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -66,6 +66,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 7d0d46dadc727..05abd9e2e6810 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -86,6 +86,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 913c47d33f488..05e6338a08b0a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.apache.activemq activemq-core diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 9137bf25ee8ae..244ad58ae9593 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -58,6 +58,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 6fec4f0e8a0f9..171df8682c848 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -57,6 +57,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index dba3dda8a9562..81794a8536318 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -58,6 +58,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 760f183a2ef37..6dd8ff69c2943 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -74,6 +74,11 @@ scalacheck_${scala.binary.version} test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/launcher/pom.xml b/launcher/pom.xml index 80696280a1d18..ed38e66aa2467 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -42,6 +42,11 @@ log4j test + + junit + junit + test + org.mockito mockito-core diff --git a/mllib/pom.xml b/mllib/pom.xml index 5dedacb38874e..22c0c6008ba37 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -94,6 +94,16 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/network/common/pom.xml b/network/common/pom.xml index 9c12cca0df609..1cc054a8936c5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -64,6 +64,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index e4f4c57b683c8..7a66c968041ce 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -78,6 +78,16 @@ test-jar test + + junit + junit + test + + + com.novocode + junit-interface + test + log4j log4j diff --git a/pom.xml b/pom.xml index 2927d3e107563..6535994641145 100644 --- a/pom.xml +++ b/pom.xml @@ -181,7 +181,6 @@ 0.9.2 ${java.home} - @@ -1964,7 +1952,6 @@ __not_used__ - ${test.exclude.tags} diff --git 
a/project/SparkBuild.scala b/project/SparkBuild.scala index d80d300f1c3b2..901cfa538d23e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -567,20 +567,11 @@ object TestSettings { javaOptions in Test ++= "-Xmx3g -Xss4096k -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, javaOptions += "-Xmx3g", - // Exclude tags defined in a system property - testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, - sys.props.get("test.exclude.tags").map { tags => - tags.split(",").flatMap { tag => Seq("-l", tag) }.toSeq - }.getOrElse(Nil): _*), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, - sys.props.get("test.exclude.tags").map { tags => - Seq("--exclude-categories=" + tags) - }.getOrElse(Nil): _*), // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), - testOptions in Test += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), // Enable Junit testing. - libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test", + libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", // Only allow one test at a time, even across projects, since they run in the same JVM parallelExecution in Test := false, // Make sure the test temp directory exists. diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fa6732db183d8..465aa3a3888c2 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,11 @@ jackson-databind ${fasterxml.jackson.version} + + junit + junit + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index ffc4c32794ca4..ab309e0a1d36b 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -24,13 +24,11 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.hive.test.TestHive /** * Runs the test cases that are included in the hive distribution. */ -@ExtendedHiveTest class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // TODO: bundle in jar files... get from classpath private lazy val hiveQueryDir = TestHive.getHiveFile( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82cfeb2bb95d3..ac67fe5f47be9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -160,6 +160,11 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.apache.spark spark-sql_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java deleted file mode 100644 index e2183183fb559..0000000000000 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/ExtendedHiveTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedHiveTest { } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 888d1b7b45532..f0bb77092c0cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.ExtendedHiveTest import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.Utils @@ -33,7 +32,6 @@ import org.apache.spark.util.Utils * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -@ExtendedHiveTest class VersionsSuite extends SparkFunSuite with Logging { // Do not use a temp path here to speed up subsequent executions of the unit test during diff --git a/streaming/pom.xml b/streaming/pom.xml index 1e6ee009ca6d5..5cc9001b0e9ab 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,11 +84,21 @@ scalacheck_${scala.binary.version} test + + junit + junit + test + org.seleniumhq.selenium selenium-java test + + com.novocode + junit-interface + test + target/scala-${scala.binary.version}/classes diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 4e8b9a84bb67f..066abe92e51c0 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -55,6 +55,16 @@ + + junit + junit + test + + + com.novocode + junit-interface + test + org.mockito mockito-core diff --git a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java b/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java deleted file mode 100644 index 7a8f2fe979c1f..0000000000000 --- a/yarn/src/test/java/org/apache/spark/deploy/yarn/ExtendedYarnTest.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn; - -import java.lang.annotation.*; -import org.scalatest.TagAnnotation; - -@TagAnnotation -@Retention(RetentionPolicy.RUNTIME) -@Target({ElementType.METHOD, ElementType.TYPE}) -public @interface ExtendedYarnTest { } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 105c3090d489d..b5a42fd6afd98 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -@ExtendedYarnTest class YarnClusterSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = new YarnConfiguration() diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 4700e2428df08..8d9c9b3004eda 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} /** * Integration test for the external shuffle service with a yarn mini-cluster */ -@ExtendedYarnTest class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { override def newYarnConfig(): YarnConfiguration = { From 841972e22c653ba58e9a65433fed203ff288f13a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Sep 2015 13:33:32 -0700 Subject: [PATCH 298/802] [SPARK-10437] [SQL] Support aggregation expressions in Order By JIRA: https://issues.apache.org/jira/browse/SPARK-10437 If an expression in `SortOrder` is a resolved one, such as `count(1)`, the corresponding rule in `Analyzer` to make it work in order by will not be applied. Author: Liang-Chi Hsieh Closes #8599 from viirya/orderby-agg. --- .../sql/catalyst/analysis/Analyzer.scala | 14 +++++++++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 591747b45c376..02f34cbf58ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -561,7 +561,7 @@ class Analyzer( } case sort @ Sort(sortOrder, global, aggregate: Aggregate) - if aggregate.resolved && !sort.resolved => + if aggregate.resolved => // Try resolving the ordering as though it is in the aggregate clause. 
try { @@ -598,9 +598,15 @@ class Analyzer( } } - Project(aggregate.output, - Sort(evaluatedOrderings, global, - aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + // Since we don't rely on sort.resolved as the stop condition for this rule, + // we need to check this and prevent applying this rule multiple times + if (sortOrder == evaluatedOrderings) { + sort + } else { + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown))) + } } catch { // Attempting to resolve in the aggregate can result in ambiguity. When this happens, // just return the original plan. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 962b100b532c9..f9981356f364f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1562,6 +1562,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) + + checkAnswer( + sql( + """ + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) + """.stripMargin), + Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) + """.stripMargin), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } test("SPARK-7952: fix the equality check between boolean and numeric types") { From 31a229aa739b6d05ec6d91b820fcca79b6b7d6fe Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Sep 2015 13:36:52 -0700 Subject: [PATCH 299/802] [SPARK-10475] [SQL] improve column prunning for Project on Sort Sometimes we can't push down the whole `Project` though `Sort`, but we still have a chance to push down part of it. Author: Wenchen Fan Closes #8644 from cloud-fan/column-prune. --- .../sql/catalyst/optimizer/Optimizer.scala | 19 +++++++++++++++---- .../optimizer/ColumnPruningSuite.scala | 11 +++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0f4caec7451a2..648a65e7c0eb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // Push down project if possible when the child is sort - case p @ Project(projectList, s @ Sort(_, _, grandChild)) - if s.references.subsetOf(p.outputSet) => - s.copy(child = Project(projectList, grandChild)) + // Push down project if possible when the child is sort. + case p @ Project(projectList, s @ Sort(_, _, grandChild)) => + if (s.references.subsetOf(p.outputSet)) { + s.copy(child = Project(projectList, grandChild)) + } else { + val neededReferences = s.references ++ p.references + if (neededReferences == grandChild.outputSet) { + // No column we can prune, return the original plan. + p + } else { + // Do not use neededReferences.toSeq directly, should respect grandChild's output order. 
+ val newProjectList = grandChild.output.filter(neededReferences.contains) + p.copy(child = s.copy(child = Project(newProjectList, grandChild))) + } + } // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index dbebcb86809de..4a1e7ceaf394b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Column pruning for Project on Sort") { + val input = LocalRelation('a.int, 'b.string, 'c.double) + + val query = input.orderBy('b.asc).select('a).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + + comparePlans(optimized, correctAnswer) + } + // todo: add more tests for column pruning } From be52faa7c72fb4b95829f09a7dc5eb5dccd03524 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 15 Sep 2015 15:46:47 -0700 Subject: [PATCH 300/802] [SPARK-7685] [ML] Apply weights to different samples in Logistic Regression In fraud detection dataset, almost all the samples are negative while only couple of them are positive. This type of high imbalanced data will bias the models toward negative resulting poor performance. In python-scikit, they provide a correction allowing users to Over-/undersample the samples of each class according to the given weights. In auto mode, selects weights inversely proportional to class frequencies in the training set. This can be done in a more efficient way by multiplying the weights into loss and gradient instead of doing actual over/undersampling in the training dataset which is very expensive. http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html On the other hand, some of the training data maybe more important like the training samples from tenure users while the training samples from new users maybe less important. We should be able to provide another "weight: Double" information in the LabeledPoint to weight them differently in the learning algorithm. Author: DB Tsai Author: DB Tsai Closes #7884 from dbtsai/SPARK-7685. 
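As context for the diff that follows: the change folds each instance's weight into the binary logistic loss and gradient and normalizes by the sum of weights instead of the instance count. Below is a rough, self-contained sketch of that objective, not the patch's actual code; `Instance` mirrors the case class introduced by the patch but uses a plain `Array[Double]`, `weightedLossAndGradient` is an illustrative name, and the intercept, feature standardization, and the numerically stable `MLUtils.log1pExp` used by the real `LogisticAggregator` are left out.

// Sketch, not the patch's API: weighted binary logistic loss and gradient.
// Each example contributes weight * loss and weight * gradient, and the totals
// are divided by the sum of weights, so duplicating an instance is equivalent
// to doubling its weight.
case class Instance(label: Double, weight: Double, features: Array[Double])

def weightedLossAndGradient(
    coefficients: Array[Double],
    data: Seq[Instance]): (Double, Array[Double]) = {
  val gradient = Array.ofDim[Double](coefficients.length)
  var lossSum = 0.0
  var weightSum = 0.0
  data.foreach { case Instance(label, weight, features) =>
    val dot = features.zip(coefficients).map { case (x, b) => x * b }.sum
    val margin = -dot
    // Gradient scale for this instance, matching the patch's `multiplier`.
    val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
    var i = 0
    while (i < coefficients.length) {
      gradient(i) += multiplier * features(i)
      i += 1
    }
    // log(1 + exp(margin)) for label 1, log(1 + exp(margin)) - margin for label 0.
    lossSum += weight * (math.log1p(math.exp(margin)) - (1.0 - label) * margin)
    weightSum += weight
  }
  (lossSum / weightSum, gradient.map(_ / weightSum))
}

This is also why the diff extends `MultiClassSummarizer` and `MultivariateOnlineSummarizer` with weight-aware add and merge: the label histogram and the feature mean/variance have to be accumulated in terms of weight sums (and, for the unbiased variance, the sum of squared weights) rather than raw counts.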
--- .../classification/LogisticRegression.scala | 199 +++++++++++------- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 12 +- .../stat/MultivariateOnlineSummarizer.scala | 75 ++++--- .../LogisticRegressionSuite.scala | 102 ++++++++- .../MultivariateOnlineSummarizerSuite.scala | 27 +++ project/MimaExcludes.scala | 10 +- 7 files changed, 303 insertions(+), 128 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a460262b87e43..bd96e8d000ff2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -29,12 +29,12 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.storage.StorageLevel /** @@ -42,7 +42,7 @@ import org.apache.spark.storage.StorageLevel */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol - with HasStandardization with HasThreshold { + with HasStandardization with HasWeightCol with HasThreshold { /** * Set threshold in binary classification, in range [0, 1]. @@ -146,6 +146,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[classification] case class Instance(label: Double, weight: Double, features: Vector) + /** * :: Experimental :: * Logistic regression. @@ -218,31 +229,42 @@ class LogisticRegression(override val uid: String) override def getThreshold: Double = super.getThreshold + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value) override def getThresholds: Array[Double] = super.getThresholds override protected def train(dataset: DataFrame): LogisticRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
- val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, labelSummarizer) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new MultiClassSummarizer))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, labelSummarizer: MultiClassSummarizer), - (label: Double, features: Vector)) => - (summarizer.add(features), labelSummarizer.add(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, - classSummarizer1: MultiClassSummarizer), (summarizer2: MultivariateOnlineSummarizer, - classSummarizer2: MultiClassSummarizer)) => - (summarizer1.merge(summarizer2), classSummarizer1.merge(classSummarizer2)) - }) + val (summarizer, labelSummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer), + c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp) + } val histogram = labelSummarizer.histogram val numInvalid = labelSummarizer.countInvalid @@ -295,7 +317,7 @@ class LogisticRegression(override val uid: String) new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialWeightsWithIntercept = + val initialCoefficientsWithIntercept = Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if ($(fitIntercept)) { @@ -312,14 +334,14 @@ class LogisticRegression(override val uid: String) b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialWeightsWithIntercept.toArray(numFeatures) - = math.log(histogram(1).toDouble / histogram(0).toDouble) + initialCoefficientsWithIntercept.toArray(numFeatures) + = math.log(histogram(1) / histogram(0)) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialWeightsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.toBreeze.toDenseVector) - val (weights, intercept, objectiveHistory) = { + val (coefficients, intercept, objectiveHistory) = { /* Note that in Logistic Regression, the objective history (loss + regularization) is log-likelihood which is invariance under feature standardization. As a result, @@ -339,28 +361,29 @@ class LogisticRegression(override val uid: String) } /* - The weights are trained in the scaled space; we're converting them back to + The coefficients are trained in the scaled space; we're converting them back to the original space. Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. 
*/ - val rawWeights = state.x.toArray.clone() + val rawCoefficients = state.x.toArray.clone() var i = 0 while (i < numFeatures) { - rawWeights(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } if ($(fitIntercept)) { - (Vectors.dense(rawWeights.dropRight(1)).compressed, rawWeights.last, arrayBuilder.result()) + (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, + arrayBuilder.result()) } else { - (Vectors.dense(rawWeights).compressed, 0.0, arrayBuilder.result()) + (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result()) } } if (handlePersistence) instances.unpersist() - val model = copyValues(new LogisticRegressionModel(uid, weights, intercept)) + val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) val logRegSummary = new BinaryLogisticRegressionTrainingSummary( model.transform(dataset), $(probabilityCol), @@ -501,22 +524,29 @@ class LogisticRegressionModel private[ml] ( * corresponding joint dataset. */ private[classification] class MultiClassSummarizer extends Serializable { - private val distinctMap = new mutable.HashMap[Int, Long] + // The first element of value in distinctMap is the actually number of instances, + // and the second element of value is sum of the weights. + private val distinctMap = new mutable.HashMap[Int, (Long, Double)] private var totalInvalidCnt: Long = 0L /** * Add a new label into this MultilabelSummarizer, and update the distinct map. * @param label The label for this data point. + * @param weight The weight of this instances. * @return This MultilabelSummarizer */ - def add(label: Double): this.type = { + def add(label: Double, weight: Double = 1.0): this.type = { + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + + if (weight == 0.0) return this + if (label - label.toInt != 0.0 || label < 0) { totalInvalidCnt += 1 this } else { - val counts: Long = distinctMap.getOrElse(label.toInt, 0L) - distinctMap.put(label.toInt, counts + 1) + val (counts: Long, weightSum: Double) = distinctMap.getOrElse(label.toInt, (0L, 0.0)) + distinctMap.put(label.toInt, (counts + 1L, weightSum + weight)) this } } @@ -537,8 +567,8 @@ private[classification] class MultiClassSummarizer extends Serializable { } smallMap.distinctMap.foreach { case (key, value) => - val counts = largeMap.distinctMap.getOrElse(key, 0L) - largeMap.distinctMap.put(key, counts + value) + val (counts: Long, weightSum: Double) = largeMap.distinctMap.getOrElse(key, (0L, 0.0)) + largeMap.distinctMap.put(key, (counts + value._1, weightSum + value._2)) } largeMap.totalInvalidCnt += smallMap.totalInvalidCnt largeMap @@ -550,13 +580,13 @@ private[classification] class MultiClassSummarizer extends Serializable { /** @return The number of distinct labels in the input dataset. */ def numClasses: Int = distinctMap.keySet.max + 1 - /** @return The counts of each label in the input dataset. */ - def histogram: Array[Long] = { - val result = Array.ofDim[Long](numClasses) + /** @return The weightSum of each label in the input dataset. 
*/ + def histogram: Array[Double] = { + val result = Array.ofDim[Double](numClasses) var i = 0 val len = result.length while (i < len) { - result(i) = distinctMap.getOrElse(i, 0L) + result(i) = distinctMap.getOrElse(i, (0L, 0.0))._2 i += 1 } result @@ -565,6 +595,8 @@ private[classification] class MultiClassSummarizer extends Serializable { /** * Abstraction for multinomial Logistic Regression Training results. + * Currently, the training summary ignores the training weights except + * for the objective trace. */ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary { @@ -584,10 +616,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the the true label of each sample. */ + /** Field in "predictions" which gives the the true label of each instance. */ def labelCol: String } @@ -597,8 +629,8 @@ sealed trait LogisticRegressionSummary extends Serializable { * Logistic regression training results. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample as a vector. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance as a vector. + * @param labelCol field in "predictions" which gives the true label of each instance. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Experimental @@ -617,8 +649,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * Binary Logistic regression results for a given model. * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each sample. - * @param labelCol field in "predictions" which gives the true label of each sample. + * each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. */ @Experimental class BinaryLogisticRegressionSummary private[classification] ( @@ -687,14 +719,14 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for samples in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in a online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. @@ -702,25 +734,25 @@ class BinaryLogisticRegressionSummary private[classification] ( * @param featuresMean The mean values of the features. 
*/ private class LogisticAggregator( - weights: Vector, + coefficients: Vector, numClasses: Int, fitIntercept: Boolean, featuresStd: Array[Double], featuresMean: Array[Double]) extends Serializable { - private var totalCnt: Long = 0L + private var weightSum = 0.0 private var lossSum = 0.0 - private val weightsArray = weights match { + private val coefficientsArray = coefficients match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"weights only supports dense vector but got type ${weights.getClass}.") + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") } - private val dim = if (fitIntercept) weightsArray.length - 1 else weightsArray.length + private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - private val gradientSumArray = Array.ofDim[Double](weightsArray.length) + private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) /** * Add a new training data to this LogisticAggregator, and update the loss and gradient @@ -729,13 +761,17 @@ private class LogisticAggregator( * @param label The label for this data point. * @param data The features for one data point in dense/sparse vector format to be added * into this aggregator. + * @param weight The weight for over-/undersamples each of training instance. Default is one. * @return This LogisticAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + + def add(label: Double, data: Vector, weight: Double = 1.0): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new instance." + s" Expecting $dim but got ${data.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val localWeightsArray = weightsArray + if (weight == 0.0) return this + + val localCoefficientsArray = coefficientsArray val localGradientSumArray = gradientSumArray numClasses match { @@ -745,13 +781,13 @@ private class LogisticAggregator( var sum = 0.0 data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localWeightsArray(index) * (value / featuresStd(index)) + sum += localCoefficientsArray(index) * (value / featuresStd(index)) } } - sum + { if (fitIntercept) localWeightsArray(dim) else 0.0 } + sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 } } - val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label) data.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { @@ -765,15 +801,15 @@ private class LogisticAggregator( if (label > 0) { // The following is equivalent to log(1 + exp(margin)) but more numerically stable. - lossSum += MLUtils.log1pExp(margin) + lossSum += weight * MLUtils.log1pExp(margin) } else { - lossSum += MLUtils.log1pExp(margin) - margin + lossSum += weight * (MLUtils.log1pExp(margin) - margin) } case _ => new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " + "binary classification for now.") } - totalCnt += 1 + weightSum += weight this } @@ -789,8 +825,8 @@ private class LogisticAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. 
Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { - totalCnt += other.totalCnt + if (other.weightSum != 0.0) { + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -805,13 +841,17 @@ private class LogisticAggregator( this } - def count: Long = totalCnt - - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -823,7 +863,7 @@ private class LogisticAggregator( * It's used in Breeze's convex optimization routines. */ private class LogisticCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], numClasses: Int, fitIntercept: Boolean, standardization: Boolean, @@ -831,22 +871,23 @@ private class LogisticCostFun( featuresMean: Array[Double], regParamL2: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length - val w = Vectors.fromBreeze(weights) + val w = Vectors.fromBreeze(coefficients) - val logisticAggregator = data.treeAggregate(new LogisticAggregator(w, numClasses, fitIntercept, - featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + val logisticAggregator = { + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance.label, instance.features, instance.weight) + val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) + + data.treeAggregate( + new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean) + )(seqOp, combOp) + } val totalGradientArray = logisticAggregator.gradient.toArray - // regVal is the sum of weight squares excluding intercept for L2 regularization. + // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e9e99ed1db40e..8049d51fee5ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -42,7 +42,7 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + " probabilities. Note: Not all models output well-calibrated probability estimates!" 
+ - " These probabilities should be treated as confidences, not precise probabilities.", + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), @@ -65,10 +65,10 @@ private[shared] object SharedParamsCodeGen { "options may be added later.", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + - " before fitting the model.", Some("true")), + " before fitting the model", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + - " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 30092170863ad..aff47fc326c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -127,10 +127,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities") setDefault(probabilityCol, "probability") @@ -270,10 +270,10 @@ private[ml] trait HasHandleInvalid extends Params { private[ml] trait HasStandardization extends Params { /** - * Param for whether to standardize the training features before fitting the model.. + * Param for whether to standardize the training features before fitting the model. 
* @group param */ - final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model.") + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features before fitting the model") setDefault(standardization, true) @@ -304,10 +304,10 @@ private[ml] trait HasSeed extends Params { private[ml] trait HasElasticNetParam extends Params { /** - * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.. + * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. * @group param */ - final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1)) + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) /** @group getParam */ final def getElasticNetParam: Double = $(elasticNetParam) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 51b713e263e0c..201333c3690df 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -23,16 +23,19 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, - * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense vector * format in a online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. * - * A numerically stable algorithm is implemented to compute sample mean and variance: + * A numerically stable algorithm is implemented to compute the mean and variance of instances: * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * + * For weighted instances, the unbiased estimation of variance is defined by the reliability + * weights: [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]. */ @Since("1.1.0") @DeveloperApi @@ -44,6 +47,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 + private var weightSum: Double = 0.0 + private var weightSquareSum: Double = 0.0 private var nnz: Array[Double] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -55,10 +60,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. 
*/ @Since("1.1.0") - def add(sample: Vector): this.type = { + def add(sample: Vector): this.type = add(sample, 1.0) + + private[spark] def add(instance: Vector, weight: Double): this.type = { + require(weight >= 0.0, s"sample weight, ${weight} has to be >= 0.0") + if (weight == 0.0) return this + if (n == 0) { - require(sample.size > 0, s"Vector should have dimension larger than zero.") - n = sample.size + require(instance.size > 0, s"Vector should have dimension larger than zero.") + n = instance.size currMean = Array.ofDim[Double](n) currM2n = Array.ofDim[Double](n) @@ -69,8 +79,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = Array.fill[Double](n)(Double.MaxValue) } - require(n == sample.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.size}.") + require(n == instance.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${instance.size}.") val localCurrMean = currMean val localCurrM2n = currM2n @@ -79,7 +89,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localNnz = nnz val localCurrMax = currMax val localCurrMin = currMin - sample.foreachActive { (index, value) => + instance.foreachActive { (index, value) => if (value != 0.0) { if (localCurrMax(index) < value) { localCurrMax(index) = value @@ -90,15 +100,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) - localCurrM2n(index) += (value - localCurrMean(index)) * diff - localCurrM2(index) += value * value - localCurrL1(index) += math.abs(value) + localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff + localCurrM2(index) += weight * value * value + localCurrL1(index) += weight * math.abs(value) - localNnz(index) += 1.0 + localNnz(index) += weight } } + weightSum += weight + weightSquareSum += weight * weight totalCnt += 1 this } @@ -112,10 +124,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.weightSum != 0.0 && other.weightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. 
" + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt + weightSum += other.weightSum + weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { val thisNnz = nnz(i) @@ -138,13 +152,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S nnz(i) = totalNnz i += 1 } - } else if (totalCnt == 0 && other.totalCnt != 0) { + } else if (weightSum == 0.0 && other.weightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt + this.weightSum = other.weightSum + this.weightSquareSum = other.weightSquareSum this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -158,28 +174,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def mean: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / totalCnt) + realMean(i) = currMean(i) * (nnz(i) / weightSum) i += 1 } Vectors.dense(realMean) } /** - * Sample variance of each dimension. + * Unbiased estimate of sample variance of each dimension. * */ @Since("1.1.0") override def variance: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = totalCnt - 1.0 + val denominator = weightSum - (weightSquareSum / weightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. 
if (denominator > 0.0) { @@ -187,9 +203,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * + (weightSum - nnz(i)) / weightSum) / denominator i += 1 } } @@ -209,7 +224,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } @@ -220,11 +235,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def max: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -236,11 +251,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def min: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) @@ -252,7 +267,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL2: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -271,7 +286,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL1: Vector = { - require(totalCnt > 0, s"Nothing has been added to this summarizer.") + require(weightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index cce39f382f738..f5219f9f574be 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.ml.classification +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -59,8 +62,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 
42) - sqlContext.createDataFrame( - generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) + sqlContext.createDataFrame(sc.parallelize(testData, 4)) } } @@ -77,6 +79,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol === "prediction") assert(lr.getRawPredictionCol === "rawPrediction") assert(lr.getProbabilityCol === "probability") + assert(lr.getWeightCol === "") assert(lr.getFitIntercept) assert(lr.getStandardization) val model = lr.fit(dataset) @@ -216,43 +219,65 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("MultiClassSummarizer") { val summarizer1 = (new MultiClassSummarizer) .add(0.0).add(3.0).add(4.0).add(3.0).add(6.0) - assert(summarizer1.histogram.zip(Array[Long](1, 0, 0, 2, 1, 0, 1)).forall(x => x._1 === x._2)) + assert(summarizer1.histogram === Array[Double](1, 0, 0, 2, 1, 0, 1)) assert(summarizer1.countInvalid === 0) assert(summarizer1.numClasses === 7) val summarizer2 = (new MultiClassSummarizer) .add(1.0).add(5.0).add(3.0).add(0.0).add(4.0).add(1.0) - assert(summarizer2.histogram.zip(Array[Long](1, 2, 0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer2.histogram === Array[Double](1, 2, 0, 1, 1, 1)) assert(summarizer2.countInvalid === 0) assert(summarizer2.numClasses === 6) val summarizer3 = (new MultiClassSummarizer) .add(0.0).add(1.3).add(5.2).add(2.5).add(2.0).add(4.0).add(4.0).add(4.0).add(1.0) - assert(summarizer3.histogram.zip(Array[Long](1, 1, 1, 0, 3)).forall(x => x._1 === x._2)) + assert(summarizer3.histogram === Array[Double](1, 1, 1, 0, 3)) assert(summarizer3.countInvalid === 3) assert(summarizer3.numClasses === 5) val summarizer4 = (new MultiClassSummarizer) .add(3.1).add(4.3).add(2.0).add(1.0).add(3.0) - assert(summarizer4.histogram.zip(Array[Long](0, 1, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizer4.histogram === Array[Double](0, 1, 1, 1)) assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) - assert(summarizerA.histogram.zip(Array[Long](2, 2, 0, 3, 2, 1, 1)).forall(x => x._1 === x._2)) + assert(summarizerA.histogram === Array[Double](2, 2, 0, 3, 2, 1, 1)) assert(summarizerA.countInvalid === 0) assert(summarizerA.numClasses === 7) // large map merges small one val summarizerB = summarizer3.merge(summarizer4) assert(summarizerB.hashCode() === summarizer3.hashCode()) - assert(summarizerB.histogram.zip(Array[Long](1, 2, 2, 1, 3)).forall(x => x._1 === x._2)) + assert(summarizerB.histogram === Array[Double](1, 2, 2, 1, 3)) assert(summarizerB.countInvalid === 5) assert(summarizerB.numClasses === 5) } + test("MultiClassSummarizer with weighted samples") { + val summarizer1 = (new MultiClassSummarizer) + .add(label = 0.0, weight = 0.2).add(3.0, 0.8).add(4.0, 3.2).add(3.0, 1.3).add(6.0, 3.1) + assert(Vectors.dense(summarizer1.histogram) ~== + Vectors.dense(Array(0.2, 0, 0, 2.1, 3.2, 0, 3.1)) absTol 1E-10) + assert(summarizer1.countInvalid === 0) + assert(summarizer1.numClasses === 7) + + val summarizer2 = (new MultiClassSummarizer) + .add(1.0, 1.1).add(5.0, 2.3).add(3.0).add(0.0).add(4.0).add(1.0).add(2, 0.0) + assert(Vectors.dense(summarizer2.histogram) ~== + Vectors.dense(Array[Double](1.0, 2.1, 0.0, 1, 1, 2.3)) absTol 1E-10) + assert(summarizer2.countInvalid === 0) + assert(summarizer2.numClasses === 6) + + val summarizer = summarizer1.merge(summarizer2) + 
assert(Vectors.dense(summarizer.histogram) ~== + Vectors.dense(Array(1.2, 2.1, 0.0, 3.1, 4.2, 2.3, 3.1)) absTol 1E-10) + assert(summarizer.countInvalid === 0) + assert(summarizer.numClasses === 7) + } + test("binary logistic regression with intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setStandardization(true) val trainer2 = (new LogisticRegression).setFitIntercept(true).setStandardization(false) @@ -713,7 +738,7 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) + val interceptTheory = math.log(histogram(1) / histogram(0)) val weightsTheory = Vectors.dense(0.0, 0.0, 0.0, 0.0) assert(model1.intercept ~== interceptTheory relTol 1E-5) @@ -781,4 +806,63 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .forall(x => x(0) >= x(1))) } + + test("binary logistic regression with weighted samples") { + val (dataset, weightedDataset) = { + val nPoints = 1000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) + + // Let's over-sample the positive samples twice. + val data1 = testData.flatMap { case labeledPoint: LabeledPoint => + if (labeledPoint.label == 1.0) { + Iterator(labeledPoint, labeledPoint) + } else { + Iterator(labeledPoint) + } + } + + val rnd = new Random(8392) + val data2 = testData.flatMap { case LabeledPoint(label: Double, features: Vector) => + if (rnd.nextGaussian() > 0.0) { + if (label == 1.0) { + Iterator( + Instance(label, 1.2, features), + Instance(label, 0.8, features), + Instance(0.0, 0.0, features)) + } else { + Iterator( + Instance(label, 0.3, features), + Instance(1.0, 0.0, features), + Instance(label, 0.1, features), + Instance(label, 0.6, features)) + } + } else { + if (label == 1.0) { + Iterator(Instance(label, 2.0, features)) + } else { + Iterator(Instance(label, 1.0, features)) + } + } + } + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LogisticRegression).setFitIntercept(true) + .setRegParam(0.0).setStandardization(true) + val trainer1b = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") + .setRegParam(0.0).setStandardization(true) + val model1a0 = trainer1a.fit(dataset) + val model1a1 = trainer1a.fit(weightedDataset) + val model1b = trainer1b.fit(weightedDataset) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 07efde4f5e6dc..b6d41db69be0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -218,4 +218,31 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { s0.merge(s1) assert(s0.mean(0) ~== 1.0 absTol 1e-14) } + + test("merging 
summarizer with weighted samples") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(instance = Vectors.sparse(3, Seq((0, -0.8), (1, 1.7))), weight = 0.1) + .add(Vectors.dense(0.0, -1.2, -1.7), 0.2).merge( + (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -0.7), (1, 0.01), (2, 1.3))), 0.15) + .add(Vectors.dense(-0.5, 0.3, -1.5), 0.05)) + + assert(summarizer.count === 4) + + // The following values are hand calculated using the formula: + // [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]] + // which defines the reliability weight used for computing the unbiased estimation of variance + // for weighted instances. + assert(summarizer.mean ~== Vectors.dense(Array(-0.42, -0.107, -0.44)) + absTol 1E-10, "mean mismatch") + assert(summarizer.variance ~== Vectors.dense(Array(0.17657142857, 1.645115714, 2.42057142857)) + absTol 1E-8, "variance mismatch") + assert(summarizer.numNonzeros ~== Vectors.dense(Array(0.3, 0.5, 0.4)) + absTol 1E-10, "numNonzeros mismatch") + assert(summarizer.max ~== Vectors.dense(Array(0.0, 1.7, 1.3)) absTol 1E-10, "max mismatch") + assert(summarizer.min ~== Vectors.dense(Array(-0.8, -1.2, -1.7)) absTol 1E-10, "min mismatch") + assert(summarizer.normL2 ~== Vectors.dense(0.387298335, 0.762571308141, 0.9715966241192) + absTol 1E-8, "normL2 mismatch") + assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 87b141cd3b058..46026c1e90ea0 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,15 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticCostFun.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.classification.LogisticAggregator.count") + ) case v if v.startsWith("1.5") => Seq( MimaBuild.excludeSparkPackage("network"), From b6e998634e05db0bb6267173e7b28f885c808c16 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 16:45:47 -0700 Subject: [PATCH 301/802] [SPARK-10548] [SPARK-10563] [SQL] Fix concurrent SQL executions *Note: this is for master branch only.* The fix for branch-1.5 is at #8721. The query execution ID is currently passed from a thread to its children, which is not the intended behavior. This led to `IllegalArgumentException: spark.sql.execution.id is already set` when running queries in parallel, e.g.: ``` (1 to 100).par.foreach { _ => sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } ``` The cause is `SparkContext`'s local properties are inherited by default. This patch adds a way to exclude keys we don't want to be inherited, and makes SQL go through that code path. Author: Andrew Or Closes #8710 from andrewor14/concurrent-sql-executions. 
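A minimal sketch of the failure mode, using only `java.util.Properties` (this is an editor's illustration, not code from the patch; the patch itself clones with commons-lang `SerializationUtils.clone`): wrapping the parent as defaults via `new Properties(parent)` keeps a live reference, so a later mutation in the parent thread is visible to the child, whereas a snapshot taken at inheritance time is not.

```
import java.util.Properties

object InheritedPropertiesSketch {
  def main(args: Array[String]): Unit = {
    val parent = new Properties()
    parent.setProperty("spark.sql.execution.id", "42")

    // Old behavior: the child only wraps the parent as its defaults (live reference).
    val linkedChild = new Properties(parent)
    // New behavior (approximated here with a shallow clone): the child gets a snapshot.
    val snapshotChild = parent.clone().asInstanceOf[Properties]

    // The parent thread later mutates its own properties...
    parent.setProperty("spark.sql.execution.id", "changed-by-parent")

    // ...and only the linked child observes the change.
    println(linkedChild.getProperty("spark.sql.execution.id"))   // changed-by-parent
    println(snapshotChild.getProperty("spark.sql.execution.id")) // 42
  }
}
```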
--- .../scala/org/apache/spark/SparkContext.scala | 9 +- .../org/apache/spark/ThreadingSuite.scala | 65 +++++------ .../sql/execution/SQLExecutionSuite.scala | 101 ++++++++++++++++++ 3 files changed, 132 insertions(+), 43 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dee6091ce3caf..a2f34eafa2c38 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, @@ -347,8 +348,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] var checkpointDir: Option[String] = None // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) + protected[spark] val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = { + // Note: make a clone such that changes in the parent properties aren't reflected in + // the those of the children threads, which has confusing semantics (SPARK-10563). + SerializationUtils.clone(parent).asInstanceOf[Properties] + } override protected def initialValue(): Properties = new Properties() } diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index a96a4ce201c21..54c131cdae367 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -147,7 +147,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { }.start() } sem.acquire(2) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } if (ThreadingSuiteState.failed.get()) { logError("Waited 1 second without seeing runningThreads = 4 (it was " + ThreadingSuiteState.runningThreads.get() + "); failing test") @@ -178,7 +178,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === null) } @@ -207,58 +207,41 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging { threads.foreach(_.start()) sem.acquire(5) - throwable.foreach { t => throw t } + throwable.foreach { t => throw improveStackTrace(t) } assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } - test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { - val jobStarted = new Semaphore(0) - val jobEnded = new Semaphore(0) - @volatile var jobResult: JobResult = null - var throwable: Option[Throwable] = None - + test("mutation in parent local property does not affect child (SPARK-10563)") { sc = new SparkContext("local", "test") - 
sc.setJobGroup("originalJobGroupId", "description") - sc.addSparkListener(new SparkListener { - override def onJobStart(jobStart: SparkListenerJobStart): Unit = { - jobStarted.release() - } - override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { - jobResult = jobEnd.jobResult - jobEnded.release() - } - }) - - // Create a new thread which will inherit the current thread's properties - val thread = new Thread() { + val originalTestValue: String = "original-value" + var threadTestValue: String = null + sc.setLocalProperty("test", originalTestValue) + var throwable: Option[Throwable] = None + val thread = new Thread { override def run(): Unit = { try { - assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") - // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task - try { - sc.parallelize(1 to 100).foreach { x => - Thread.sleep(100) - } - } catch { - case s: SparkException => // ignored so that we don't print noise in test logs - } + threadTestValue = sc.getLocalProperty("test") } catch { case t: Throwable => throwable = Some(t) } } } + sc.setLocalProperty("test", "this-should-not-be-inherited") thread.start() - // Wait for the job to start, then mutate the original properties, which should have been - // inherited by the running job but hopefully defensively copied or snapshotted: - jobStarted.tryAcquire(10, TimeUnit.SECONDS) - sc.setJobGroup("modifiedJobGroupId", "description") - // Canceling the original job group should cancel the running job. In other words, the - // modification of the properties object should not affect the properties of running jobs - sc.cancelJobGroup("originalJobGroupId") - jobEnded.tryAcquire(10, TimeUnit.SECONDS) - throwable.foreach { t => throw t } - assert(jobResult.isInstanceOf[JobFailed]) + thread.join() + throwable.foreach { t => throw improveStackTrace(t) } + assert(threadTestValue === originalTestValue) } + + /** + * Improve the stack trace of an error thrown from within a thread. + * Otherwise it's difficult to tell which line in the test the error came from. + */ + private def improveStackTrace(t: Throwable): Throwable = { + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + t + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala new file mode 100644 index 0000000000000..63639681ef80a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import java.util.Properties + +import scala.collection.parallel.CompositeThrowable + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.SQLContext + +class SQLExecutionSuite extends SparkFunSuite { + + test("concurrent query execution (SPARK-10548)") { + // Try to reproduce the issue with the old SparkContext + val conf = new SparkConf() + .setMaster("local[*]") + .setAppName("test") + val badSparkContext = new BadSparkContext(conf) + try { + testConcurrentQueryExecution(badSparkContext) + fail("unable to reproduce SPARK-10548") + } catch { + case e: IllegalArgumentException => + assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY)) + } finally { + badSparkContext.stop() + } + + // Verify that the issue is fixed with the latest SparkContext + val goodSparkContext = new SparkContext(conf) + try { + testConcurrentQueryExecution(goodSparkContext) + } finally { + goodSparkContext.stop() + } + } + + /** + * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. + */ + private def testConcurrentQueryExecution(sc: SparkContext): Unit = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Initialize local properties. This is necessary for the test to pass. + sc.getLocalProperties + + // Set up a thread that runs executes a simple SQL query. + // Before starting the thread, mutate the execution ID in the parent. + // The child thread should not see the effect of this change. + var throwable: Option[Throwable] = None + val child = new Thread { + override def run(): Unit = { + try { + sc.parallelize(1 to 100).map { i => (i, i) }.toDF("a", "b").collect() + } catch { + case t: Throwable => + throwable = Some(t) + } + + } + } + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, "anything") + child.start() + child.join() + + // The throwable is thrown from the child thread so it doesn't have a helpful stack trace + throwable.foreach { t => + t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace) + throw t + } + } + +} + +/** + * A bad [[SparkContext]] that does not clone the inheritable thread local properties + * when passing them to children threads. + */ +private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { + protected[spark] override val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } +} From a63cdc769f511e98b38c3318bcc732c9a6c76c22 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 15 Sep 2015 16:53:27 -0700 Subject: [PATCH 302/802] [SPARK-10612] [SQL] Add prepare to LocalNode. The idea is that we should separate the function call that does memory reservation (i.e. prepare) from the function call that consumes the input (e.g. open()), so all operators can be a chance to reserve memory before they are all consumed. Author: Reynold Xin Closes #8761 from rxin/SPARK-10612. 
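A minimal sketch of the intended call order, using made-up names (`SketchNode`, `SketchFilter`) rather than the real `LocalNode` API: `prepare()` recurses over the whole operator tree and reserves resources first, and only afterwards does `open()` start consuming input.

```
trait SketchNode {
  def children: Seq[SketchNode]
  // Reserve resources only; must not consume any input, and must recurse into children.
  def prepare(): Unit = children.foreach(_.prepare())
  def open(): Unit
}

class SketchLeaf extends SketchNode {
  val children: Seq[SketchNode] = Seq.empty
  def open(): Unit = println("leaf: start producing rows")
}

class SketchFilter(child: SketchNode) extends SketchNode {
  val children: Seq[SketchNode] = Seq(child)
  override def prepare(): Unit = {
    println("filter: reserve memory") // reservation happens before any open()
    super.prepare()
  }
  def open(): Unit = {
    child.open()
    println("filter: start consuming rows")
  }
}

object PrepareThenOpen {
  def main(args: Array[String]): Unit = {
    val plan = new SketchFilter(new SketchLeaf)
    plan.prepare() // the whole tree reserves memory first
    plan.open()    // only then is any input consumed
  }
}
```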
--- .../org/apache/spark/sql/execution/local/LocalNode.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 9840080e16953..569cff565c092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -45,6 +45,14 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging def output: Seq[Attribute] + /** + * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume + * any input data. + * + * Implementations of this must also call the `prepare()` function of its children. + */ + def prepare(): Unit = children.foreach(_.prepare()) + /** * Initializes the iterator state. Must be called before calling `next()`. * From 99ecfa5945aedaa71765ecf5cce59964ae52eebe Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 15 Sep 2015 17:01:10 -0700 Subject: [PATCH 303/802] [SPARK-10575] [SPARK CORE] Wrapped RDD.takeSample with Scope Remove return statements in RDD.takeSample and wrap it withScope Author: vinodkc Author: vinodkc Author: Vinod K C Closes #8730 from vinodkc/fix_takesample_return. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 68 +++++++++---------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7dd2bc5d7cd72..a56e542242d5f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -469,50 +469,44 @@ abstract class RDD[T: ClassTag]( * @param seed seed for the random number generator * @return sample of specified size in an array */ - // TODO: rewrite this without return statements so we can wrap it in a scope def takeSample( withReplacement: Boolean, num: Int, - seed: Long = Utils.random.nextLong): Array[T] = { + seed: Long = Utils.random.nextLong): Array[T] = withScope { val numStDev = 10.0 - if (num < 0) { - throw new IllegalArgumentException("Negative number of elements requested") - } else if (num == 0) { - return new Array[T](0) - } - - val initialCount = this.count() - if (initialCount == 0) { - return new Array[T](0) - } - - val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt - if (num > maxSampleSize) { - throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " + - s"$numStDev * math.sqrt(Int.MaxValue)") - } - - val rand = new Random(seed) - if (!withReplacement && num >= initialCount) { - return Utils.randomizeInPlace(this.collect(), rand) - } - - val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, - withReplacement) - - var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + require(num >= 0, "Negative number of elements requested") + require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt), + "Cannot support a sample size > Int.MaxValue - " + + s"$numStDev * math.sqrt(Int.MaxValue)") - // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for the initial size - var numIters = 0 - while (samples.length < num) { - logWarning(s"Needed to re-sample due to insufficient sample size. 
Repeat #$numIters") - samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() - numIters += 1 + if (num == 0) { + new Array[T](0) + } else { + val initialCount = this.count() + if (initialCount == 0) { + new Array[T](0) + } else { + val rand = new Random(seed) + if (!withReplacement && num >= initialCount) { + Utils.randomizeInPlace(this.collect(), rand) + } else { + val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount, + withReplacement) + var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + + // If the first sample didn't turn out large enough, keep trying to take samples; + // this shouldn't happen often because we use a big multiplier for the initial size + var numIters = 0 + while (samples.length < num) { + logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters") + samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() + numIters += 1 + } + Utils.randomizeInPlace(samples, rand).take(num) + } + } } - - Utils.randomizeInPlace(samples, rand).take(num) } /** From 38700ea40cb1dd0805cc926a9e629f93c99527ad Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 15 Sep 2015 17:11:21 -0700 Subject: [PATCH 304/802] [SPARK-10381] Fix mixup of taskAttemptNumber & attemptId in OutputCommitCoordinator When speculative execution is enabled, consider a scenario where the authorized committer of a particular output partition fails during the OutputCommitter.commitTask() call. In this case, the OutputCommitCoordinator is supposed to release that committer's exclusive lock on committing once that task fails. However, due to a unit mismatch (we used task attempt number in one place and task attempt id in another) the lock will not be released, causing Spark to go into an infinite retry loop. This bug was masked by the fact that the OutputCommitCoordinator does not have enough end-to-end tests (the current tests use many mocks). Other factors contributing to this bug are the fact that we have many similarly-named identifiers that have different semantics but the same data types (e.g. attemptNumber and taskAttemptId, with inconsistent variable naming which makes them difficult to distinguish). This patch adds a regression test and fixes this bug by always using task attempt numbers throughout this code. Author: Josh Rosen Closes #8544 from JoshRosen/SPARK-10381. 
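A toy sketch of the unit mismatch (made-up types, not the real `OutputCommitCoordinator`): the coordinator records one kind of identifier, but the failure path compares against the other, so the check never matches, the commit lock is never released, and every later attempt is denied.

```
object CommitLockSketch {
  // A globally unique attempt ID (e.g. 4517L) vs. the per-task retry counter
  // (0, 1, 2, ...). Both are numeric, which is why the mixup type-checked.
  final case class Attempt(taskAttemptId: Long, attemptNumber: Int)

  private var authorizedCommitter: Option[Long] = None // stores a taskAttemptId

  def canCommit(a: Attempt): Boolean = authorizedCommitter match {
    case None    => authorizedCommitter = Some(a.taskAttemptId); true
    case Some(_) => false
  }

  // Buggy failure handler: compares the stored attempt ID against the attempt *number*.
  def taskFailedBuggy(a: Attempt): Unit = {
    if (authorizedCommitter.contains(a.attemptNumber.toLong)) {
      authorizedCommitter = None
    }
  }

  def main(args: Array[String]): Unit = {
    val first = Attempt(taskAttemptId = 4517L, attemptNumber = 0)
    println(canCommit(first)) // true: first becomes the authorized committer
    taskFailedBuggy(first)    // 4517L != 0, so the lock is NOT cleared
    val retry = Attempt(taskAttemptId = 4518L, attemptNumber = 1)
    println(canCommit(retry)) // false: the retry is denied forever
  }
}
```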
--- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../org/apache/spark/TaskEndReason.scala | 7 +- .../executor/CommitDeniedException.scala | 4 +- .../spark/mapred/SparkHadoopMapRedUtil.scala | 20 ++---- .../apache/spark/scheduler/DAGScheduler.scala | 7 +- .../scheduler/OutputCommitCoordinator.scala | 48 +++++++------ .../org/apache/spark/scheduler/TaskInfo.scala | 7 +- .../status/api/v1/AllStagesResource.scala | 2 +- .../org/apache/spark/ui/jobs/StagePage.scala | 4 +- .../org/apache/spark/util/JsonProtocol.scala | 2 +- ...putCommitCoordinatorIntegrationSuite.scala | 68 +++++++++++++++++++ .../OutputCommitCoordinatorSuite.scala | 24 ++++--- .../apache/spark/util/JsonProtocolSuite.scala | 2 +- project/MimaExcludes.scala | 36 +++++++++- .../datasources/WriterContainer.scala | 3 +- .../sql/execution/ui/SQLListenerSuite.scala | 4 +- .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 17 files changed, 174 insertions(+), 69 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ae5926dd534a6..ac6eaab20d8d2 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -104,8 +104,7 @@ class SparkHadoopWriter(jobConf: JobConf) } def commit() { - SparkHadoopMapRedUtil.commitTask( - getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID) } def commitJob() { diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 2ae878b3e6087..7137246bc34f2 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -193,9 +193,12 @@ case object TaskKilled extends TaskFailedReason { * Task requested the driver to commit, but was denied. */ @DeveloperApi -case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason { +case class TaskCommitDenied( + jobID: Int, + partitionID: Int, + attemptNumber: Int) extends TaskFailedReason { override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + - s" for job: $jobID, partition: $partitionID, attempt: $attemptID" + s" for job: $jobID, partition: $partitionID, attemptNumber: $attemptNumber" /** * If a task failed because its attempt to commit was denied, do not count this failure * towards failing the stage. 
This is intended to prevent spurious stage failures in cases diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala index f47d7ef511da1..7d84889a2def0 100644 --- a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala +++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala @@ -26,8 +26,8 @@ private[spark] class CommitDeniedException( msg: String, jobID: Int, splitID: Int, - attemptID: Int) + attemptNumber: Int) extends Exception(msg) { - def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID) + def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptNumber) } diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f405b732e4725..f7298e8d5c62c 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -91,8 +91,7 @@ object SparkHadoopMapRedUtil extends Logging { committer: MapReduceOutputCommitter, mrTaskContext: MapReduceTaskAttemptContext, jobId: Int, - splitId: Int, - attemptId: Int): Unit = { + splitId: Int): Unit = { val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) @@ -122,7 +121,8 @@ object SparkHadoopMapRedUtil extends Logging { if (shouldCoordinateWithDriver) { val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator - val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId) + val taskAttemptNumber = TaskContext.get().attemptNumber() + val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, taskAttemptNumber) if (canCommit) { performCommit() @@ -132,7 +132,7 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(message) // We need to abort the task so that the driver can reschedule new attempts, if necessary committer.abortTask(mrTaskContext) - throw new CommitDeniedException(message, jobId, splitId, attemptId) + throw new CommitDeniedException(message, jobId, splitId, taskAttemptNumber) } } else { // Speculation is disabled or a user has chosen to manually bypass the commit coordination @@ -143,16 +143,4 @@ object SparkHadoopMapRedUtil extends Logging { logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID") } } - - def commitTask( - committer: MapReduceOutputCommitter, - mrTaskContext: MapReduceTaskAttemptContext, - sparkTaskContext: TaskContext): Unit = { - commitTask( - committer, - mrTaskContext, - sparkTaskContext.stageId(), - sparkTaskContext.partitionId(), - sparkTaskContext.attemptNumber()) - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b4f90e8347894..3c9a66e504403 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1128,8 +1128,11 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) - outputCommitCoordinator.taskCompleted(stageId, task.partitionId, - event.taskInfo.attempt, event.reason) + outputCommitCoordinator.taskCompleted( + stageId, + task.partitionId, + event.taskInfo.attemptNumber, // this is a task attempt number + event.reason) // The success case is dealt with separately 
below, since we need to compute accumulator // updates before posting. diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5d926377ce86b..add0dedc03f44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint private sealed trait OutputCommitCoordinationMessage extends Serializable private case object StopCoordinator extends OutputCommitCoordinationMessage -private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long) +private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int) /** * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins" @@ -44,8 +44,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int - private type PartitionId = Long - private type TaskAttemptId = Long + private type PartitionId = Int + private type TaskAttemptNumber = Int /** * Map from active stages's id => partition id => task attempt with exclusive lock on committing @@ -57,7 +57,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() - private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]] + private type CommittersByStageMap = + mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptNumber]] /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
@@ -75,14 +76,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * * @param stage the stage number * @param partition the partition number - * @param attempt a unique identifier for this task attempt + * @param attemptNumber how many times this task has been attempted + * (see [[TaskContext.attemptNumber()]]) * @return true if this task is authorized to commit, false otherwise */ def canCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = { - val msg = AskPermissionToCommitOutput(stage, partition, attempt) + attemptNumber: TaskAttemptNumber): Boolean = { + val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => endpointRef.askWithRetry[Boolean](msg) @@ -95,7 +97,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) // Called by DAGScheduler private[scheduler] def stageStart(stage: StageId): Unit = synchronized { - authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]() + authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptNumber]() } // Called by DAGScheduler @@ -107,7 +109,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def taskCompleted( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId, + attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") @@ -117,12 +119,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) case Success => // The task output has been committed successfully case denied: TaskCommitDenied => - logInfo( - s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt") + logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters.get(partition).exists(_ == attempt)) { - logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" + - s" clearing lock") + if (authorizedCommitters.get(partition).exists(_ == attemptNumber)) { + logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + + s"partition=$partition) failed; clearing lock") authorizedCommitters.remove(partition) } } @@ -140,21 +142,23 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private[scheduler] def handleAskPermissionToCommit( stage: StageId, partition: PartitionId, - attempt: TaskAttemptId): Boolean = synchronized { + attemptNumber: TaskAttemptNumber): Boolean = synchronized { authorizedCommittersByStage.get(stage) match { case Some(authorizedCommitters) => authorizedCommitters.get(partition) match { case Some(existingCommitter) => - logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " + - s"existingCommitter = $existingCommitter") + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") false case None => - logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition") - authorizedCommitters(partition) = attempt + logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition") + authorizedCommitters(partition) = attemptNumber true } case None => - 
logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + + s"partition $partition to commit") false } } @@ -174,9 +178,9 @@ private[spark] object OutputCommitCoordinator { } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case AskPermissionToCommitOutput(stage, partition, taskAttempt) => + case AskPermissionToCommitOutput(stage, partition, attemptNumber) => context.reply( - outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber)) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 132a9ced77700..f113c2b1b8433 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -29,7 +29,7 @@ import org.apache.spark.annotation.DeveloperApi class TaskInfo( val taskId: Long, val index: Int, - val attempt: Int, + val attemptNumber: Int, val launchTime: Long, val executorId: String, val host: String, @@ -95,7 +95,10 @@ class TaskInfo( } } - def id: String = s"$index.$attempt" + @deprecated("Use attemptNumber", "1.6.0") + def attempt: Int = attemptNumber + + def id: String = s"$index.$attemptNumber" def duration: Long = { if (!finished) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 390c136df79b3..24a0b5220695c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -127,7 +127,7 @@ private[v1] object AllStagesResource { new TaskData( taskId = uiData.taskInfo.taskId, index = uiData.taskInfo.index, - attempt = uiData.taskInfo.attempt, + attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 2b71f55b7bb4f..712782d27b3cf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -621,7 +621,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { serializationTimeProportionPos + serializationTimeProportion val index = taskInfo.index - val attempt = taskInfo.attempt + val attempt = taskInfo.attemptNumber val svgTag = if (totalExecutionTime == 0) { @@ -967,7 +967,7 @@ private[ui] class TaskDataSource( new TaskTableRowData( info.index, info.taskId, - info.attempt, + info.attemptNumber, info.speculative, info.status, info.taskLocality.toString, diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 24f78744ad74c..99614a786bd93 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -266,7 +266,7 @@ private[spark] object JsonProtocol { def taskInfoToJson(taskInfo: TaskInfo): JValue = { ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ - ("Attempt" -> taskInfo.attempt) ~ + 
("Attempt" -> taskInfo.attemptNumber) ~ ("Launch Time" -> taskInfo.launchTime) ~ ("Executor ID" -> taskInfo.executorId) ~ ("Host" -> taskInfo.host) ~ diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala new file mode 100644 index 0000000000000..1ae5b030f0832 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.{Span, Seconds} + +import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.util.Utils + +/** + * Integration tests for the OutputCommitCoordinator. + * + * See also: [[OutputCommitCoordinatorSuite]] for unit tests that use mocks. 
+ */ +class OutputCommitCoordinatorIntegrationSuite + extends SparkFunSuite + with LocalSparkContext + with Timeouts { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .set("master", "local[2,4]") + .set("spark.speculation", "true") + .set("spark.hadoop.mapred.output.committer.class", + classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) + sc = new SparkContext("local[2, 4]", "test", conf) + } + + test("exception thrown in OutputCommitter.commitTask()") { + // Regression test for SPARK-10381 + failAfter(Span(60, Seconds)) { + val tempDir = Utils.createTempDir() + try { + sc.parallelize(1 to 4, 2).map(_.toString).saveAsTextFile(tempDir.getAbsolutePath + "/out") + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +private class ThrowExceptionOnFirstAttemptOutputCommitter extends FileOutputCommitter { + override def commitTask(context: TaskAttemptContext): Unit = { + val ctx = TaskContext.get() + if (ctx.attemptNumber < 1) { + throw new java.io.FileNotFoundException("Intentional exception") + } + super.commitTask(context) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index e5ecd4b7c2610..6d08d7c5b7d2a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -63,6 +63,9 @@ import scala.language.postfixOps * was not in SparkHadoopWriter, the tests would still pass because only one of the * increments would be captured even though the commit in both tasks was executed * erroneously. + * + * See also: [[OutputCommitCoordinatorIntegrationSuite]] for integration tests that do + * not use mocks. 
*/ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { @@ -164,27 +167,28 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") { val stage: Int = 1 - val partition: Long = 2 - val authorizedCommitter: Long = 3 - val nonAuthorizedCommitter: Long = 100 + val partition: Int = 2 + val authorizedCommitter: Int = 3 + val nonAuthorizedCommitter: Int = 100 outputCommitCoordinator.stageStart(stage) - assert(outputCommitCoordinator.canCommit(stage, partition, attempt = authorizedCommitter)) - assert(!outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter)) + + assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter)) + assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attempt = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) // New tasks should still not be able to commit because the authorized committer has not failed assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 1)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attempt = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) // A new task should now be allowed to become the authorized committer assert( - outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 2)) + outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) // There can only be one authorized committer assert( - !outputCommitCoordinator.canCommit(stage, partition, attempt = nonAuthorizedCommitter + 3)) + !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 47e548ef0d442..143c1b901df11 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -499,7 +499,7 @@ class JsonProtocolSuite extends SparkFunSuite { private def assertEquals(info1: TaskInfo, info2: TaskInfo) { assert(info1.taskId === info2.taskId) assert(info1.index === info2.index) - assert(info1.attempt === info2.attempt) + assert(info1.attemptNumber === info2.attemptNumber) assert(info1.launchTime === info2.launchTime) assert(info1.executorId === info2.executorId) assert(info1.host === info2.host) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46026c1e90ea0..1c96b0958586f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,7 @@ object MimaExcludes { excludePackage("org.apache.spark.sql.execution") ) ++ MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ + MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ Seq( ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticCostFun.this"), @@ -53,6 +53,23 @@ object MimaExcludes { 
"org.apache.spark.ml.classification.LogisticAggregator.add"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.ml.classification.LogisticAggregator.count") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.5") => Seq( @@ -213,6 +230,23 @@ object MimaExcludes { // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.mllib.linalg.VectorUDT.serialize") + ) ++ Seq( + // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. + // This class is marked as `private` but MiMa still seems to be confused by the change. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") ) case v if v.startsWith("1.4") => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f8ef674ed29c1..cfd64c1d9eb34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -198,8 +198,7 @@ private[sql] abstract class BaseWriterContainer( } def commitTask(): Unit = { - SparkHadoopMapRedUtil.commitTask( - outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId) + SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId) } def abortTask(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 
2bbb41ca777b7..7a46c69a056b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -54,9 +54,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { details = "" ) - private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + private def createTaskInfo(taskId: Int, attemptNumber: Int): TaskInfo = new TaskInfo( taskId = taskId, - attempt = attempt, + attemptNumber = attemptNumber, // The following fields are not used in tests index = 0, launchTime = 0, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 4ca8042d22367..c8d6b718045a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -121,7 +121,7 @@ private[hive] class SparkHiveWriterContainer( } protected def commit() { - SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID) + SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID) } private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { From 35a19f3357d2ec017cfefb90f1018403e9617de4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 15 Sep 2015 17:24:32 -0700 Subject: [PATCH 305/802] [SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and reduces a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624) where the output of `TakeOrderedAndProjectNode` is not actually ordered. Author: Andrew Or Closes #8764 from andrewor14/sql-local-tests-cleanup. 
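A small sketch of the SPARK-10624 pitfall, written against plain Scala collections rather than Spark's bounded priority queue (names and numbers are illustrative only): a size-bounded heap retains the correct top-k elements, but iterating it yields heap order, so the result has to be sorted explicitly before it is handed back.

```
import scala.collection.mutable

object TopKOrderingSketch {
  // Keep the k smallest values using a max-heap of size k.
  def topK(input: Seq[Int], k: Int): Seq[Int] = {
    val heap = mutable.PriorityQueue.empty[Int] // max-heap by default
    input.foreach { x =>
      heap.enqueue(x)
      if (heap.size > k) heap.dequeue() // evict the current maximum
    }
    heap.toSeq // right elements, but in unspecified heap order
  }

  def main(args: Array[String]): Unit = {
    val data = Seq(9, 1, 7, 3, 5, 2, 8)
    println(topK(data, 3))        // correct set {1, 2, 3}, but not necessarily sorted
    println(topK(data, 3).sorted) // the fix: sort before emitting -> List(1, 2, 3)
  }
}
```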
--- .../spark/sql/execution/local/LocalNode.scala | 8 +- .../sql/execution/local/SampleNode.scala | 16 +- .../local/TakeOrderedAndProjectNode.scala | 2 +- .../spark/sql/execution/SparkPlanTest.scala | 2 +- .../spark/sql/execution/local/DummyNode.scala | 68 ++++ .../sql/execution/local/ExpandNodeSuite.scala | 54 ++- .../sql/execution/local/FilterNodeSuite.scala | 34 +- .../execution/local/HashJoinNodeSuite.scala | 141 ++++---- .../execution/local/IntersectNodeSuite.scala | 24 +- .../sql/execution/local/LimitNodeSuite.scala | 28 +- .../sql/execution/local/LocalNodeSuite.scala | 73 +--- .../sql/execution/local/LocalNodeTest.scala | 165 ++------- .../local/NestedLoopJoinNodeSuite.scala | 316 ++++++------------ .../execution/local/ProjectNodeSuite.scala | 39 ++- .../sql/execution/local/SampleNodeSuite.scala | 35 +- .../TakeOrderedAndProjectNodeSuite.scala | 50 ++- .../sql/execution/local/UnionNodeSuite.scala | 49 +-- 17 files changed, 468 insertions(+), 636 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 569cff565c092..f96b62a67a254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType * Before consuming the iterator, open function must be called. * After consuming the iterator, close function must be called. */ -abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging { +abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging { protected val codegenEnabled: Boolean = conf.codegenEnabled protected val unsafeEnabled: Boolean = conf.unsafeEnabled - lazy val schema: StructType = StructType.fromAttributes(output) - private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing") - def output: Seq[Attribute] - /** * Called before open(). Prepare can be used to reserve memory needed. It must NOT consume * any input data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala index abf3df1c0c2af..793700803f216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.local -import java.util.Random - import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + /** * Sample the dataset. 
* @@ -51,18 +50,15 @@ case class SampleNode( override def open(): Unit = { child.open() - val (sampler, _seed) = if (withReplacement) { - val random = new Random(seed) + val sampler = + if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, // requiring us to copy the row, which is more expensive than the random number generator. - (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), - // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result - // of DataFrame - random.nextLong()) + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false) } else { - (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed) + new BernoulliCellSampler[InternalRow](lowerBound, upperBound) } - sampler.setSeed(_seed) + sampler.setSeed(seed) iterator = sampler.sample(child.asIterator) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala index 53f1dcc65d8cf..ae672fbca8d83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala @@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode( } // Close it eagerly since we don't need it. child.close() - iterator = queue.iterator + iterator = queue.toArray.sorted(ord).iterator } override def next(): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index de45ae4635fb7..3d218f01c9ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -238,7 +238,7 @@ object SparkPlanTest { outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan.transformExpressions { + plan transformExpressions { case UnresolvedAttribute(Seq(u)) => inputMap.getOrElse(u, sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala new file mode 100644 index 0000000000000..efc3227dd60d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala @@ -0,0 +1,68 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.execution.local + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation + +/** + * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]]. + */ +private[local] case class DummyNode( + output: Seq[Attribute], + relation: LocalRelation, + conf: SQLConf) + extends LocalNode(conf) { + + import DummyNode._ + + private var index: Int = CLOSED + private val input: Seq[InternalRow] = relation.data + + def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) { + this(output, LocalRelation.fromProduct(output, data), conf) + } + + def isOpen: Boolean = index != CLOSED + + override def children: Seq[LocalNode] = Seq.empty + + override def open(): Unit = { + index = -1 + } + + override def next(): Boolean = { + index += 1 + index < input.size + } + + override def fetch(): InternalRow = { + assert(index >= 0 && index < input.size) + input(index) + } + + override def close(): Unit = { + index = CLOSED + } +} + +private object DummyNode { + val CLOSED: Int = Int.MinValue +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala index cfa7f3f6dcb97..bbd94d8da2d11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala @@ -17,35 +17,33 @@ package org.apache.spark.sql.execution.local +import org.apache.spark.sql.catalyst.dsl.expressions._ + + class ExpandNodeSuite extends LocalNodeTest { - import testImplicits._ - - test("expand") { - val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value") - checkAnswer( - input, - node => - ExpandNode(conf, Seq( - Seq( - input.col("key") + input.col("value"), input.col("key") - input.col("value") - ).map(_.expr), - Seq( - input.col("key") * input.col("value"), input.col("key") / input.col("value") - ).map(_.expr) - ), node.output, node), - Seq( - (2, 0), - (1, 1), - (4, 0), - (4, 1), - (6, 0), - (9, 1), - (8, 0), - (16, 1), - (10, 0), - (25, 1) - ).toDF().collect() - ) + private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v)) + val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode) + val resolvedNode = resolveExpressions(expandNode) + val expectedOutput = { + val firstHalf = inputData.map { case (k, v) => (k + v, k - v) } + val secondHalf = inputData.map { case (k, v) => (k * v, k / v) } + firstHalf ++ secondHalf + } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput.toSet === expectedOutput.toSet) + } + + test("empty") { + testExpand() } + + test("basic") { + testExpand((1 to 100).map { i => (i, i * 1000) }.toArray) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala index a12670e347c25..4eadce646d379 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala @@ -17,25 
+17,29 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.dsl.expressions._ -class FilterNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val condition = (testData.col("key") % 2) === 0 - checkAnswer( - testData, - node => FilterNode(conf, condition.expr, node), - testData.filter(condition).collect() - ) +class FilterNodeSuite extends LocalNodeTest { + + private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = { + val cond = 'k % 2 === 0 + val inputNode = new DummyNode(kvIntAttributes, inputData) + val filterNode = new FilterNode(conf, cond, inputNode) + val resolvedNode = resolveExpressions(filterNode) + val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 } + val actualOutput = resolvedNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val condition = (emptyTestData.col("key") % 2) === 0 - checkAnswer( - emptyTestData, - node => FilterNode(conf, condition.expr, node), - emptyTestData.filter(condition).collect() - ) + testFilter() + } + + test("basic") { + testFilter((1 to 100).map { i => (i, i) }.toArray) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index 78d891351f4a9..5c1bdb088eeed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -18,99 +18,80 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.execution.joins +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class HashJoinNodeSuite extends LocalNodeTest { - import testImplicits._ + // Test all combinations of the two dimensions: with/out unsafe and build sides + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + testJoin(unsafeAndCodegen, buildSide) + } + } - def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = { - test(s"$suiteName: inner join with one match per row") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(upperCaseData.col("N").expr), - Seq(lowerCaseData.col("n").expr), - joins.BuildLeft, - node1, - node2) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N").collect() - ) + /** + * Test inner hash join with varying degrees of matches. 
+ */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = { + val rightInputMap = rightInput.toMap + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions(new HashJoinNode( + conf, Seq('id1), Seq('id2), buildSide, node1, node2)) + } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = leftInput + .filter { case (k, _) => rightInputMap.contains(k) } + .map { case (k, v) => (k, v, k, rightInputMap(k)) } + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) } + assert(actualOutput === expectedOutput) } - test(s"$suiteName: inner join with multiple matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 1).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - x.join(y).where($"x.a" === $"y.a").collect() - ) - } + test(s"$testNamePrefix: empty") { + runTest(Array.empty, Array.empty) + runTest(someData, Array.empty) + runTest(Array.empty, someData) } - test(s"$suiteName: inner join, no matches") { - withSQLConf(confPairs: _*) { - val x = testData2.where($"a" === 1).as("x") - val y = testData2.where($"a" === 2).as("y") - checkAnswer2( - x, - y, - wrapForUnsafe( - (node1, node2) => HashJoinNode( - conf, - Seq(x.col("a").expr), - Seq(y.col("a").expr), - joins.BuildLeft, - node1, - node2) - ), - Nil - ) - } + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray + runTest(someData, Array.empty) + runTest(Array.empty, someData) + runTest(someData, someIrrelevantData) + runTest(someIrrelevantData, someData) } - test(s"$suiteName: big inner join, 4 matches per row") { - withSQLConf(confPairs: _*) { - val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as("x") - val bigDataY = bigData.as("y") + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(someData, someOtherData) + runTest(someOtherData, someData) + } - checkAnswer2( - bigDataX, - bigDataY, - wrapForUnsafe( - (node1, node2) => - HashJoinNode( - conf, - Seq(bigDataX.col("key").expr), - Seq(bigDataY.col("key").expr), - joins.BuildLeft, - node1, - node2) - ), - bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect()) - } + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray + runTest(someData, someSuperRelevantData) + runTest(someSuperRelevantData, someData) } } - joinSuite( - "general", SQLConf.CODEGEN_ENABLED.key -> "false", 
SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala index 7deaa375fcfc2..c0ad2021b204a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.execution.local -class IntersectNodeSuite extends LocalNodeTest { - import testImplicits._ +class IntersectNodeSuite extends LocalNodeTest { test("basic") { - val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value") - - checkAnswer2( - input1, - input2, - (node1, node2) => IntersectNode(conf, node1, node2), - input1.intersect(input2).collect() - ) + val n = 100 + val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray + val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray + val leftNode = new DummyNode(kvIntAttributes, leftData) + val rightNode = new DummyNode(kvIntAttributes, rightData) + val intersectNode = new IntersectNode(conf, leftNode, rightNode) + val expectedOutput = leftData.intersect(rightData) + val actualOutput = intersectNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala index 3b183902007e4..fb790636a3689 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class LimitNodeSuite extends LocalNodeTest with SharedSQLContext { +class LimitNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer( - testData, - node => LimitNode(conf, 10, node), - testData.limit(10).collect() - ) + private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = { + val inputNode = new DummyNode(kvIntAttributes, inputData) + val limitNode = new LimitNode(conf, limit, inputNode) + val expectedOutput = inputData.take(limit) + val actualOutput = limitNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer( - emptyTestData, - node => LimitNode(conf, 10, node), - emptyTestData.limit(10).collect() - ) + testLimit() } + + test("basic") { + testLimit((1 to 100).map { i => (i, i) }.toArray, 20) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala index b89fa46f8b3b4..0d1ed99eec6cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala @@ -17,28 +17,24 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLConf -import 
org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.IntegerType -class LocalNodeSuite extends SparkFunSuite { - private val data = (1 to 100).toArray +class LocalNodeSuite extends LocalNodeTest { + private val data = (1 to 100).map { i => (i, i) }.toArray test("basic open, next, fetch, close") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) assert(!node.isOpen) node.open() assert(node.isOpen) - data.foreach { i => + data.foreach { case (k, v) => assert(node.next()) // fetch should be idempotent val fetched = node.fetch() assert(node.fetch() === fetched) assert(node.fetch() === fetched) - assert(node.fetch().numFields === 1) - assert(node.fetch().getInt(0) === i) + assert(node.fetch().numFields === 2) + assert(node.fetch().getInt(0) === k) + assert(node.fetch().getInt(1) === v) } assert(!node.next()) node.close() @@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite { } test("asIterator") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) val iter = node.asIterator node.open() - data.foreach { i => + data.foreach { case (k, v) => // hasNext should be idempotent assert(iter.hasNext) assert(iter.hasNext) val item = iter.next() - assert(item.numFields === 1) - assert(item.getInt(0) === i) + assert(item.numFields === 2) + assert(item.getInt(0) === k) + assert(item.getInt(1) === v) } intercept[NoSuchElementException] { iter.next() @@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite { } test("collect") { - val node = new DummyLocalNode(data) + val node = new DummyNode(kvIntAttributes, data) node.open() val collected = node.collect() assert(collected.size === data.size) - assert(collected.forall(_.size === 1)) - assert(collected.map(_.getInt(0)) === data) + assert(collected.forall(_.size === 2)) + assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data) node.close() } } - -/** - * A dummy [[LocalNode]] that just returns one row per integer in the input. 
- */ -private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) { - private var index = Int.MinValue - - def this(input: Array[Int]) { - this(new SQLConf, input) - } - - def isOpen: Boolean = { - index != Int.MinValue - } - - override def output: Seq[Attribute] = { - Seq(AttributeReference("something", IntegerType)()) - } - - override def children: Seq[LocalNode] = Seq.empty - - override def open(): Unit = { - index = -1 - } - - override def next(): Boolean = { - index += 1 - index < input.size - } - - override def fetch(): InternalRow = { - assert(index >= 0 && index < input.size) - val values = Array(input(index).asInstanceOf[Any]) - new GenericInternalRow(values) - } - - override def close(): Unit = { - index = Int.MinValue - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 86dd28064cc6a..098050bcd2236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -17,147 +17,54 @@ package org.apache.spark.sql.execution.local -import scala.util.control.NonFatal - import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLConf} -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.types.{IntegerType, StringType} -class LocalNodeTest extends SparkFunSuite with SharedSQLContext { - def conf: SQLConf = sqlContext.conf +class LocalNodeTest extends SparkFunSuite { - protected def wrapForUnsafe( - f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { - if (conf.unsafeEnabled) { - (left: LocalNode, right: LocalNode) => { - val _left = ConvertToUnsafeNode(conf, left) - val _right = ConvertToUnsafeNode(conf, right) - val r = f(_left, _right) - ConvertToSafeNode(conf, r) - } - } else { - f - } - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. - */ - protected def checkAnswer( - input: DataFrame, - nodeFunction: LocalNode => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - input :: Nil, - nodes => nodeFunction(nodes.head), - expectedAnswer, - sortAnswers) - } - - /** - * Runs the LocalNode and makes sure the answer matches the expected result. - * @param left the left input data to be used. - * @param right the right input data to be used. - * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate - * the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. 
- */ - protected def checkAnswer2( - left: DataFrame, - right: DataFrame, - nodeFunction: (LocalNode, LocalNode) => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - doCheckAnswer( - left :: right :: Nil, - nodes => nodeFunction(nodes(0), nodes(1)), - expectedAnswer, - sortAnswers) - } + protected val conf: SQLConf = new SQLConf + protected val kvIntAttributes = Seq( + AttributeReference("k", IntegerType)(), + AttributeReference("v", IntegerType)()) + protected val joinNameAttributes = Seq( + AttributeReference("id1", IntegerType)(), + AttributeReference("name", StringType)()) + protected val joinNicknameAttributes = Seq( + AttributeReference("id2", IntegerType)(), + AttributeReference("nickname", StringType)()) /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Wrap a function processing two [[LocalNode]]s such that: + * (1) all input rows are automatically converted to unsafe rows + * (2) all output rows are automatically converted back to safe rows */ - protected def doCheckAnswer( - input: Seq[DataFrame], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean = true): Unit = { - LocalNodeTest.checkAnswer( - input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { - case Some(errorMessage) => fail(errorMessage) - case None => + protected def wrapForUnsafe( + f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = { + (left: LocalNode, right: LocalNode) => { + val _left = ConvertToUnsafeNode(conf, left) + val _right = ConvertToUnsafeNode(conf, right) + val r = f(_left, _right) + ConvertToSafeNode(conf, r) } } - protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { - new SeqScanNode( - conf, - df.queryExecution.sparkPlan.output, - df.queryExecution.toRdd.map(_.copy()).collect()) - } - -} - -/** - * Helper methods for writing tests of individual local physical operators. - */ -object LocalNodeTest { - /** - * Runs the `LocalNode`s and makes sure the answer matches the expected result. - * @param input the input data to be used. - * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to - * instantiate the local physical operator that's being tested. - * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. - * @param sortAnswers if true, the answers will be sorted by their toString representations prior - * to being compared. + * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes. 
*/ - def checkAnswer( - input: Seq[SeqScanNode], - nodeFunction: Seq[LocalNode] => LocalNode, - expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { - - val outputNode = nodeFunction(input) - - val outputResult: Seq[Row] = try { - outputNode.collect() - } catch { - case NonFatal(e) => - val errorMessage = - s""" - | Exception thrown while executing local plan: - | $outputNode - | == Exception == - | $e - | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} - """.stripMargin - return Some(errorMessage) - } - - SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => - s""" - | Results do not match for local plan: - | $outputNode - | $errorMessage - """.stripMargin + protected def resolveExpressions(outputNode: LocalNode): LocalNode = { + outputNode transform { + case node: LocalNode => + val inputMap = node.output.map { a => (a.name, a) }.toMap + node transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala index b1ef26ba82f16..40299d9d5ee37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala @@ -18,222 +18,128 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} + class NestedLoopJoinNodeSuite extends LocalNodeTest { - import testImplicits._ - - private def joinSuite( - suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = { - test(s"$suiteName: left outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("n") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - upperCaseData.col("N") > 1).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect()) - - checkAnswer2( - upperCaseData, - lowerCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - LeftOuter, - Some( - (upperCaseData.col("N") === lowerCaseData.col("n") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", 
"left").collect()) + // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types + private val maybeUnsafeAndCodegen = Seq(false, true) + private val buildSides = Seq(BuildLeft, BuildRight) + private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter) + maybeUnsafeAndCodegen.foreach { unsafeAndCodegen => + buildSides.foreach { buildSide => + joinTypes.foreach { joinType => + testJoin(unsafeAndCodegen, buildSide, joinType) } } + } - test(s"$suiteName: right outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - RightOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect()) + /** + * Test outer nested loop joins with varying degrees of matches. 
+ */ + private def testJoin( + unsafeAndCodegen: Boolean, + buildSide: BuildSide, + joinType: JoinType): Unit = { + val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe" + val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType" + val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray + val conf = new SQLConf + conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen) + conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen) + + // Actual test body + def runTest( + joinType: JoinType, + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)]): Unit = { + val leftNode = new DummyNode(joinNameAttributes, leftInput) + val rightNode = new DummyNode(joinNicknameAttributes, rightInput) + val cond = 'id1 === 'id2 + val makeNode = (node1: LocalNode, node2: LocalNode) => { + resolveExpressions( + new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond))) } + val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode + val hashJoinNode = makeUnsafeNode(leftNode, rightNode) + val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType) + val actualOutput = hashJoinNode.collect().map { row => + // (id, name, id, nickname) + (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3)) + } + assert(actualOutput.toSet === expectedOutput.toSet) } - test(s"$suiteName: full outer join") { - withSQLConf(confPairs: _*) { - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("n") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - upperCaseData.col("N") > 1).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect()) - - checkAnswer2( - lowerCaseData, - upperCaseData, - wrapForUnsafe( - (node1, node2) => NestedLoopJoinNode( - conf, - node1, - node2, - buildSide, - FullOuter, - Some((lowerCaseData.col("n") === upperCaseData.col("N") && - lowerCaseData.col("l") > upperCaseData.col("L")).expr)) - ), - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect()) - } + test(s"$testNamePrefix: empty") { + runTest(joinType, Array.empty, Array.empty) + } + + test(s"$testNamePrefix: no matches") { + val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray + runTest(joinType, someData, Array.empty) + runTest(joinType, Array.empty, someData) + runTest(joinType, someData, someIrrelevantData) + runTest(joinType, someIrrelevantData, someData) + } + + test(s"$testNamePrefix: partial matches") { + val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray + runTest(joinType, someData, someOtherData) + runTest(joinType, someOtherData, someData) + } + + test(s"$testNamePrefix: full matches") { + val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) } + runTest(joinType, 
someData, someSuperRelevantData) + runTest(joinType, someSuperRelevantData, someData) + } + } + + /** + * Helper method to generate the expected output of a test based on the join type. + */ + private def generateExpectedOutput( + leftInput: Array[(Int, String)], + rightInput: Array[(Int, String)], + joinType: JoinType): Array[(Int, String, Int, String)] = { + joinType match { + case LeftOuter => + val rightInputMap = rightInput.toMap + leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + + case RightOuter => + val leftInputMap = leftInput.toMap + rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + + case FullOuter => + val leftInputMap = leftInput.toMap + val rightInputMap = rightInput.toMap + val leftOutput = leftInput.map { case (k, v) => + val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0) + val rightValue = rightInputMap.getOrElse(k, null) + (k, v, rightKey, rightValue) + } + val rightOutput = rightInput.map { case (k, v) => + val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0) + val leftValue = leftInputMap.getOrElse(k, null) + (leftKey, leftValue, k, v) + } + (leftOutput ++ rightOutput).distinct + + case other => + throw new IllegalArgumentException(s"Join type $other is not applicable") } } - joinSuite( - "general-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "general-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false") - joinSuite( - "tungsten-build-left", - BuildLeft, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") - joinSuite( - "tungsten-build-right", - BuildRight, - SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala index 38e0a230c46d8..02ecb23d34b2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -17,28 +17,33 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression} +import org.apache.spark.sql.types.{IntegerType, StringType} -class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext { - test("basic") { - val output = testData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - testData, - node => ProjectNode(conf, columns, node), - testData.select("value", "key").collect() - ) +class ProjectNodeSuite extends LocalNodeTest { + private val pieAttributes = Seq( + AttributeReference("id", IntegerType)(), + AttributeReference("age", IntegerType)(), + AttributeReference("name", StringType)()) + + private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = { + val inputNode = new DummyNode(pieAttributes, inputData) + val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2)) + val projectNode = new ProjectNode(conf, columns, inputNode) + val expectedOutput = inputData.map { case 
(id, age, name) => (id, name) } + val actualOutput = projectNode.collect().map { case row => + (row.getInt(0), row.getString(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - val output = emptyTestData.queryExecution.sparkPlan.output - val columns = Seq(output(1), output(0)) - checkAnswer( - emptyTestData, - node => ProjectNode(conf, columns, node), - emptyTestData.select("value", "key").collect() - ) + testProject() + } + + test("basic") { + testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala index 87a7da453999c..a3e83bbd51457 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala @@ -17,21 +17,32 @@ package org.apache.spark.sql.execution.local -class SampleNodeSuite extends LocalNodeTest { +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} + - import testImplicits._ +class SampleNodeSuite extends LocalNodeTest { private def testSample(withReplacement: Boolean): Unit = { - test(s"withReplacement: $withReplacement") { - val seed = 0L - val input = sqlContext.sparkContext. - parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition - toDF("key", "value") - checkAnswer( - input, - node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node), - input.sample(withReplacement, 0.3, seed).collect() - ) + val seed = 0L + val lowerb = 0.0 + val upperb = 0.3 + val maybeOut = if (withReplacement) "" else "out" + test(s"with$maybeOut replacement") { + val inputData = (1 to 1000).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode) + val sampler = + if (withReplacement) { + new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false) + } else { + new BernoulliCellSampler[(Int, Int)](lowerb, upperb) + } + sampler.setSeed(seed) + val expectedOutput = sampler.sample(inputData.iterator).toArray + val actualOutput = sampleNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala index ff28b24eeff14..42ebc7bfcaadc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala @@ -17,38 +17,34 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder} +import scala.util.Random -class TakeOrderedAndProjectNodeSuite extends LocalNodeTest { +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SortOrder - import testImplicits._ - private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = { - val sortOrder: Seq[SortOrder] = sortExprs.map { col => - col.expr match { - case expr: SortOrder => - expr - case expr: Expression => - SortOrder(expr, Ascending) - } - } - sortOrder - } +class 
TakeOrderedAndProjectNodeSuite extends LocalNodeTest { - private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = { - val testCaseName = if (desc) "desc" else "asc" - test(testCaseName) { - val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value") - val sortColumn = if (desc) input.col("key").desc else input.col("key") - checkAnswer( - input, - node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node), - input.sort(sortColumn).limit(5).collect() - ) + private def testTakeOrderedAndProject(desc: Boolean): Unit = { + val limit = 10 + val ascOrDesc = if (desc) "desc" else "asc" + test(ascOrDesc) { + val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray + val inputNode = new DummyNode(kvIntAttributes, inputData) + val firstColumn = inputNode.output(0) + val sortDirection = if (desc) Descending else Ascending + val sortOrder = SortOrder(firstColumn, sortDirection) + val takeOrderAndProjectNode = new TakeOrderedAndProjectNode( + conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode) + val expectedOutput = inputData + .map { case (k, _) => k } + .sortBy { k => k * (if (desc) -1 else 1) } + .take(limit) + val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) } + assert(actualOutput === expectedOutput) } } - testTakeOrderedAndProjectNode(desc = false) - testTakeOrderedAndProjectNode(desc = true) + testTakeOrderedAndProject(desc = false) + testTakeOrderedAndProject(desc = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala index eedd7320900f9..666b0235c061d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala @@ -17,36 +17,39 @@ package org.apache.spark.sql.execution.local -import org.apache.spark.sql.test.SharedSQLContext -class UnionNodeSuite extends LocalNodeTest with SharedSQLContext { +class UnionNodeSuite extends LocalNodeTest { - test("basic") { - checkAnswer2( - testData, - testData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - testData.unionAll(testData).collect() - ) + private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = { + val inputNodes = inputData.map { data => + new DummyNode(kvIntAttributes, data) + } + val unionNode = new UnionNode(conf, inputNodes) + val expectedOutput = inputData.flatten + val actualOutput = unionNode.collect().map { case row => + (row.getInt(0), row.getInt(1)) + } + assert(actualOutput === expectedOutput) } test("empty") { - checkAnswer2( - emptyTestData, - emptyTestData, - (node1, node2) => UnionNode(conf, Seq(node1, node2)), - emptyTestData.unionAll(emptyTestData).collect() - ) + testUnion(Seq(Array.empty)) + testUnion(Seq(Array.empty, Array.empty)) + } + + test("self") { + val data = (1 to 100).map { i => (i, i) }.toArray + testUnion(Seq(data)) + testUnion(Seq(data, data)) + testUnion(Seq(data, data, data)) } - test("complicated union") { - val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData, - emptyTestData, emptyTestData, testData, emptyTestData) - doCheckAnswer( - dfs, - nodes => UnionNode(conf, nodes), - dfs.reduce(_.unionAll(_)).collect() - ) + test("basic") { + val zero = Array.empty[(Int, Int)] + val one = (1 to 100).map { i => (i, i) }.toArray + val two = (50 to 150).map { i => (i, i) }.toArray + val three = (800 to 900).map { i 
=> (i, i) }.toArray + testUnion(Seq(zero, one, two, three)) } } From 64c29afcb787d9f176a197c25314295108ba0471 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Tue, 15 Sep 2015 19:41:38 -0700 Subject: [PATCH 306/802] [SPARK-9078] [SQL] Allow jdbc dialects to override the query used to check the table. Current implementation uses query with a LIMIT clause to find if table already exists. This syntax works only in some database systems. This patch changes the default query to the one that is likely to work on most databases, and adds a new method to the JdbcDialect abstract class to allow dialects to override the default query. I looked at using the JDBC meta data calls, it turns out there is no common way to find the current schema, catalog..etc. There is a new method Connection.getSchema() , but that is available only starting jdk1.7 , and existing jdbc drivers may not have implemented it. Other option was to use jdbc escape syntax clause for LIMIT, not sure on how well this supported in all the databases also. After looking at all the jdbc metadata options my conclusion was most common way is to use the simple select query with 'where 1 =0' , and allow dialects to customize as needed Author: sureshthalamati Closes #8676 from sureshthalamati/table_exists_spark-9078. --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../datasources/jdbc/JdbcUtils.scala | 9 ++++++--- .../apache/spark/sql/jdbc/JdbcDialects.scala | 20 +++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 14 +++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b2a66dd417b4c..745bb4ec9cf1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 26788b2a4fd69..f89d55b20e212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -42,10 +42,13 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, table: String): Boolean = { + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overriden by the dialects. 
+ Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b98..68ebaaca6c53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ed710689cc670..5ab9381de4d66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } } From b921fe4dc0442aa133ab7d55fba24bc798d59aa2 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 15 Sep 2015 19:43:26 -0700 Subject: [PATCH 307/802] [SPARK-10595] [ML] [MLLIB] [DOCS] Various ML guide cleanups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various ML guide cleanups. * ml-guide.md: Make it easier to access the algorithm-specific guides. * LDA user guide: EM often begins with useless topics, but running longer generally improves them dramatically. E.g., 10 iterations on a Wikipedia dataset produces useless topics, but 50 iterations produces very meaningful topics. * mllib-feature-extraction.html#elementwiseproduct: “w” parameter should be “scalingVec” * Clean up Binarizer user guide a little. 
* Document in Pipeline that users should not put an instance into the Pipeline in more than 1 place. * spark.ml Word2Vec user guide: clean up grammar/writing * Chi Sq Feature Selector docs: Improve text in doc. CC: mengxr feynmanliang Author: Joseph K. Bradley Closes #8752 from jkbradley/mlguide-fixes-1.5. --- docs/ml-features.md | 34 +++++++++++++++++--- docs/ml-guide.md | 31 ++++++++++++------- docs/mllib-clustering.md | 4 +++ docs/mllib-feature-extraction.md | 53 +++++++++++++++++++++----------- docs/mllib-guide.md | 4 +-- 5 files changed, 91 insertions(+), 35 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a414c21b5c280..b70da4ac63845 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -123,12 +123,21 @@ for features_label in rescaledData.select("features", "label").take(3): ## Word2Vec -`Word2Vec` is an `Estimator` which takes sequences of words that represents documents and trains a `Word2VecModel`. The model is a `Map(String, Vector)` essentially, which maps each word to an unique fix-sized vector. The `Word2VecModel` transforms each documents into a vector using the average of all words in the document, which aims to other computations of documents such as similarity calculation consequencely. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more details on Word2Vec. +`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a +`Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` +transforms each document into a vector using the average of all words in the document; this vector +can then be used for as features for prediction, document similarity calculations, etc. +Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#Word2Vec) for more +details. -Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of them is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm. +In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
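As a complement to the API examples that follow, the paragraph above describes the document vector as the average of the document's word vectors. The short, framework-free Scala sketch below makes that averaging step concrete; the word vectors are invented values rather than output of a trained Word2VecModel, and the object name is hypothetical.

object DocumentVectorSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical word -> vector table standing in for a trained model's vectors.
    val wordVectors: Map[String, Array[Double]] = Map(
      "hi"    -> Array(0.1, 0.2, 0.3),
      "spark" -> Array(0.4, 0.0, 0.6),
      "mllib" -> Array(0.1, 0.4, 0.0))

    val document = Seq("hi", "spark", "mllib")

    // The document vector is the element-wise average of its word vectors,
    // which is the transformation described in the guide text above.
    val size = wordVectors(document.head).length
    val sums = document.map(wordVectors).foldLeft(Array.fill(size)(0.0)) { (acc, v) =>
      acc.zip(v).map { case (a, b) => a + b }
    }
    val docVector = sums.map(_ / document.size)

    // Approximately [0.2, 0.2, 0.3], up to floating-point rounding.
    println(docVector.mkString("[", ", ", "]"))
  }
}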
+ +Refer to the [Word2Vec Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec) +for more details on the API. + {% highlight scala %} import org.apache.spark.ml.feature.Word2Vec @@ -152,6 +161,10 @@ result.select("result").take(3).foreach(println)
+ +Refer to the [Word2Vec Java docs](api/java/org/apache/spark/ml/feature/Word2Vec.html) +for more details on the API. + {% highlight java %} import java.util.Arrays; @@ -192,6 +205,10 @@ for (Row r: result.select("result").take(3)) {
+ +Refer to the [Word2Vec Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.Word2Vec) +for more details on the API. + {% highlight python %} from pyspark.ml.feature import Word2Vec @@ -621,12 +638,15 @@ for ngrams_label in ngramDataFrame.select("ngrams", "label").take(3): ## Binarizer -Binarization is the process of thresholding numerical features to binary features. As some probabilistic estimators make assumption that the input data is distributed according to [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing the input data with continuous numerical features. +Binarization is the process of thresholding numerical features to binary (0/1) features. -A simple [Binarizer](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) class provides this functionality. Besides the common parameters of `inputCol` and `outputCol`, `Binarizer` has the parameter `threshold` used for binarizing continuous numerical features. The features greater than the threshold, will be binarized to 1.0. The features equal to or less than the threshold, will be binarized to 0.0. The example below shows how to binarize numerical features. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0.
+ +Refer to the [Binarizer API doc](api/scala/index.html#org.apache.spark.ml.feature.Binarizer) for more details. + {% highlight scala %} import org.apache.spark.ml.feature.Binarizer import org.apache.spark.sql.DataFrame @@ -650,6 +670,9 @@ binarizedFeatures.collect().foreach(println)
+ +Refer to the [Binarizer API doc](api/java/org/apache/spark/ml/feature/Binarizer.html) for more details. + {% highlight java %} import java.util.Arrays; @@ -687,6 +710,9 @@ for (Row r : binarizedFeatures.collect()) {
+ +Refer to the [Binarizer API doc](api/python/pyspark.ml.html#pyspark.ml.feature.Binarizer) for more details. + {% highlight python %} from pyspark.ml.feature import Binarizer diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 78c93a95c7807..c5d7f990021f1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -32,7 +32,21 @@ See the [algorithm guides](#algorithm-guides) section below for guides on sub-pa * This will become a table of contents (this text will be scraped). {:toc} -# Main concepts +# Algorithm guides + +We provide several algorithm guides specific to the Pipelines API. +Several of these algorithms, such as certain feature transformers, are not in the `spark.mllib` API. +Also, some algorithms have additional capabilities in the `spark.ml` API; e.g., random forests +provide class probabilities, and linear models provide model summaries. + +* [Feature extraction, transformation, and selection](ml-features.html) +* [Decision Trees for classification and regression](ml-decision-tree.html) +* [Ensembles](ml-ensembles.html) +* [Linear methods with elastic net regularization](ml-linear-methods.html) +* [Multilayer perceptron classifier](ml-ann.html) + + +# Main concepts in Pipelines Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. @@ -166,6 +180,11 @@ compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + ## Parameters Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. @@ -184,16 +203,6 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm guides - -There are now several algorithms in the Pipelines API which are not in the `spark.mllib` API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -* [Feature extraction, transformation, and selection](ml-features.html) -* [Decision Trees for classification and regression](ml-decision-tree.html) -* [Ensembles](ml-ensembles.html) -* [Linear methods with elastic net regularization](ml-linear-methods.html) -* [Multilayer perceptron classifier](ml-ann.html) - # Code examples This section gives code examples illustrating the functionality discussed above. diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 3fb35d3c50b06..c2711cf82deb4 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -507,6 +507,10 @@ must also be $> 1.0$. 
Providing `Vector(-1)` results in default behavior $> 1.0$. Providing `-1` results in defaulting to a value of $0.1 + 1$. * `maxIterations`: The maximum number of EM iterations. +*Note*: It is important to do enough iterations. In early iterations, EM often has useless topics, +but those topics improve dramatically after more iterations. Using at least 20 and possibly +50-100 iterations is often reasonable, depending on your dataset. + `EMLDAOptimizer` produces a `DistributedLDAModel`, which stores not only the inferred topics but also the full training corpus and topic distributions for each document in the training corpus. A diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index de86aba2ae627..7e417ed5f37a9 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -380,35 +380,43 @@ data2 = labels.zip(normalizer2.transform(features))
-## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. +## ChiSqSelector -### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) tries to identify relevant +features for use in model construction. It reduces the size of the feature space, which can improve +both speed and statistical learning behavior. -#### Model Fitting +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements +Chi-Squared feature selection. It operates on labeled data with categorical features. +`ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, +and then filters (selects) the top features which the class label depends on the most. +This is akin to yielding the features with the most predictive power. -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the -following parameters in the constructor: +The number of features to select can be tuned using a held-out validation set. -* `numTopFeatures` number of top features that the selector will select (filter). +### Model Fitting -We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in -`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that +the selector will select. -This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) -which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes +an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then +returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. +The `ChiSqSelectorModel` can be applied either to a `Vector` to produce a reduced `Vector`, or to an `RDD[Vector]` to produce a reduced `RDD[Vector]`. Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). -#### Example +### Example The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
-
+
+ +Refer to the [`ChiSqSelector` Scala docs](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) +for details on the API. + {% highlight scala %} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors @@ -434,7 +442,11 @@ val filteredData = discretizedData.map { lp => {% endhighlight %}
-
+
+ +Refer to the [`ChiSqSelector` Java docs](api/java/org/apache/spark/mllib/feature/ChiSqSelector.html) +for details on the API. + {% highlight java %} import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; @@ -486,7 +498,12 @@ sc.stop(); ## ElementwiseProduct -ElementwiseProduct multiplies each input vector by a provided "weight" vector, using element-wise multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) between the input vector, `v` and transforming vector, `w`, to yield a result vector. +`ElementwiseProduct` multiplies each input vector by a provided "weight" vector, using element-wise +multiplication. In other words, it scales each column of the dataset by a scalar multiplier. This +represents the [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_%28matrices%29) +between the input vector, `v` and transforming vector, `scalingVec`, to yield a result vector. + +Denoting the `scalingVec` as "`w`," this transformation may be written as: `\[ \begin{pmatrix} v_1 \\ @@ -506,7 +523,7 @@ v_N [`ElementwiseProduct`](api/scala/index.html#org.apache.spark.mllib.feature.ElementwiseProduct) has the following parameter in the constructor: -* `w`: the transforming vector. +* `scalingVec`: the transforming vector. `ElementwiseProduct` implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the weighting on a `Vector` to produce a transformed `Vector` or on an `RDD[Vector]` to produce a transformed `RDD[Vector]`. diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 257f7cc7603fa..91e50ccfecec4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -13,9 +13,9 @@ primitives and higher-level pipeline APIs. It divides into two packages: -* [`spark.mllib`](mllib-guide.html#mllib-types-algorithms-and-utilities) contains the original API +* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). -* [`spark.ml`](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) provides higher-level API +* [`spark.ml`](ml-guide.html) provides higher-level API built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. From 95b6a8103fb527f501ca26b1d6e3a5859970a1e2 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 15 Sep 2015 23:25:51 -0700 Subject: [PATCH 308/802] [SPARK-10516] [ MLLIB] Added values property in DenseVector Author: Vinod K C Closes #8682 from vinodkc/fix_SPARK-10516.
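
A quick illustration of the new property, as a PySpark sketch (this snippet is not part of the patch; it assumes a local PySpark installation with NumPy available):

{% highlight python %}
from pyspark.mllib.linalg import DenseVector

dv = DenseVector([1.0, 2.0, 3.0])

# The added `values` property returns the same underlying NumPy array as
# toArray(), giving DenseVector the same accessor SparseVector already has.
assert (dv.values == dv.toArray()).all()
print(dv.values)  # e.g. [ 1.  2.  3.]
{% endhighlight %}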
--- python/pyspark/mllib/linalg/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 380f86e9b44f8..4829acb16ed8a 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -399,6 +399,10 @@ def squared_distance(self, other): def toArray(self): return self.array + @property + def values(self): + return self.array + def __getitem__(self, item): return self.array[item] From 1894653edce718e874d1ddc9ba442bce43cbc082 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Wed, 16 Sep 2015 10:47:30 +0100 Subject: [PATCH 309/802] [SPARK-10511] [BUILD] Reset git repository before packaging source distro The calculation of Spark version is downloading Scala and Zinc in the build directory which is inflating the size of the source distribution. Reseting the repo before packaging the source distribution fix this issue. Author: Luciano Resende Closes #8774 from lresende/spark-10511. --- dev/create-release/release-build.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index d0b3a54dde1dc..9dac43ce54425 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -99,6 +99,7 @@ fi DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" USER_HOST="$ASF_USERNAME@people.apache.org" +git clean -d -f -x rm .gitignore rm -rf .git cd .. From d9b7f3e4dbceb91ea4d1a1fed3ab847335f8588b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Wed, 16 Sep 2015 04:34:14 -0700 Subject: [PATCH 310/802] [SPARK-10276] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.recommendation Author: Yu ISHIKAWA Closes #8677 from yu-iskw/SPARK-10276. --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 506ca2151cce7..95047b5b7b4b7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -18,7 +18,7 @@ import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable @@ -36,6 +36,8 @@ class Rating(namedtuple("Rating", ["user", "product", "rating"])): (1, 2, 5.0) >>> (r[0], r[1], r[2]) (1, 2, 5.0) + + .. versionadded:: 1.2.0 """ def __reduce__(self): @@ -111,13 +113,17 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... rmtree(path) ... except OSError: ... pass + + .. versionadded:: 0.9.0 """ + @since("0.9.0") def predict(self, user, product): """ Predicts rating for the given user and product. """ return self._java_model.predict(int(user), int(product)) + @since("0.9.0") def predictAll(self, user_product): """ Returns a list of predicted ratings for input user and product pairs. 
@@ -128,6 +134,7 @@ def predictAll(self, user_product): user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) + @since("1.2.0") def userFeatures(self): """ Returns a paired RDD, where the first element is the user and the @@ -135,6 +142,7 @@ def userFeatures(self): """ return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.2.0") def productFeatures(self): """ Returns a paired RDD, where the first element is the product and the @@ -142,6 +150,7 @@ def productFeatures(self): """ return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) + @since("1.4.0") def recommendUsers(self, product, num): """ Recommends the top "num" number of users for a given product and returns a list @@ -149,6 +158,7 @@ def recommendUsers(self, product, num): """ return list(self.call("recommendUsers", product, num)) + @since("1.4.0") def recommendProducts(self, user, num): """ Recommends the top "num" number of products for a given user and returns a list @@ -157,17 +167,25 @@ def recommendProducts(self, user, num): return list(self.call("recommendProducts", user, num)) @property + @since("1.4.0") def rank(self): + """Rank for the features in this model""" return self.call("rank") @classmethod + @since("1.3.1") def load(cls, sc, path): + """Load a model from the given path""" model = cls._load_java(sc, path) wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) class ALS(object): + """Alternating Least Squares matrix factorization + + .. versionadded:: 0.9.0 + """ @classmethod def _prepare(cls, ratings): @@ -188,15 +206,31 @@ def _prepare(cls, ratings): return ratings @classmethod + @since("0.9.0") def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of ratings given by users to some products, + in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + product of two lower-rank matrices of a given rank (number of features). To solve for these + features, we run a given number of iterations of ALS. This is done using a level of + parallelism given by `blocks`. + """ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, nonnegative, seed) return MatrixFactorizationModel(model) @classmethod + @since("0.9.0") def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, nonnegative=False, seed=None): + """ + Train a matrix factorization model given an RDD of 'implicit preferences' given by users + to some products, in the form of (userID, productID, preference) pairs. We approximate the + ratings matrix as the product of two lower-rank matrices of a given rank (number of + features). To solve for these features, we run a given number of iterations of ALS. + This is done using a level of parallelism given by `blocks`. + """ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, iterations, lambda_, blocks, alpha, nonnegative, seed) return MatrixFactorizationModel(model) From 5dbaf3d3911bbfa003bc75459aaad66b4f6e0c67 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 16 Sep 2015 19:19:23 +0100 Subject: [PATCH 311/802] [SPARK-10589] [WEBUI] Add defense against external site framing Set `X-Frame-Options: SAMEORIGIN` to protect against frame-related vulnerability Author: Sean Owen Closes #8745 from srowen/SPARK-10589. 
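
The change below exposes the header through a new `spark.ui.allowFramingFrom` configuration key. A rough usage sketch from PySpark (the origin URL is a placeholder, not something defined by the patch):

{% highlight python %}
from pyspark import SparkConf, SparkContext

# By default the UI servlets now answer with "X-Frame-Options: SAMEORIGIN".
# To permit framing from a single trusted origin, set the new key instead.
conf = (SparkConf()
        .setMaster("local[2]")
        .setAppName("ui-framing-demo")
        .set("spark.ui.allowFramingFrom", "https://example.com/"))
sc = SparkContext(conf=conf)
# UI responses should then carry: X-Frame-Options: ALLOW-FROM https://example.com/
sc.stop()
{% endhighlight %}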
--- .../spark/deploy/worker/ui/WorkerWebUI.scala | 7 ++++--- .../org/apache/spark/metrics/MetricsSystem.scala | 2 +- .../spark/metrics/sink/MetricsServlet.scala | 6 +++--- .../scala/org/apache/spark/ui/JettyUtils.scala | 16 ++++++++++++++-- .../main/scala/org/apache/spark/ui/WebUI.scala | 4 ++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 709a27233598c..1a0598e50dcf1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,9 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -49,7 +48,9 @@ class WorkerWebUI( attachPage(new WorkerPage(this)) attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) attachHandler(createServletHandler("/log", - (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + (request: HttpServletRequest) => logPage.renderLog(request), + worker.securityMgr, + worker.conf)) } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4517f465ebd3b..48afe3ae3511f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -88,7 +88,7 @@ private[spark] class MetricsSystem private ( */ def getServletHandlers: Array[ServletContextHandler] = { require(running, "Can only call getServletHandlers on a running MetricsSystem") - metricsServlet.map(_.getHandlers).getOrElse(Array()) + metricsServlet.map(_.getHandlers(conf)).getOrElse(Array()) } metricsConfig.initialize() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 0c2e212a33074..4193e1d21d3c1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.SecurityManager +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( @@ -49,10 +49,10 @@ private[spark] class MetricsServlet( val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers: Array[ServletContextHandler] = { + def getHandlers(conf: SparkConf): Array[ServletContextHandler] = { Array[ServletContextHandler]( createServletHandler(servletPath, - new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr) + new ServletParams(request => getMetricsSnapshot(request), "text/json"), securityMgr, conf) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala 
b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 779c0ba083596..b796a44fe01ac 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -59,7 +59,17 @@ private[spark] object JettyUtils extends Logging { def createServlet[T <% AnyRef]( servletParams: ServletParams[T], - securityMgr: SecurityManager): HttpServlet = { + securityMgr: SecurityManager, + conf: SparkConf): HttpServlet = { + + // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options + // (see http://tools.ietf.org/html/rfc7034). By default allow framing only from the + // same origin, but allow framing for a specific named URI. + // Example: spark.ui.allowFramingFrom = https://example.com/ + val allowFramingFrom = conf.getOption("spark.ui.allowFramingFrom") + val xFrameOptionsValue = + allowFramingFrom.map(uri => s"ALLOW-FROM $uri").getOrElse("SAMEORIGIN") + new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { try { @@ -68,6 +78,7 @@ private[spark] object JettyUtils extends Logging { response.setStatus(HttpServletResponse.SC_OK) val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.setHeader("X-Frame-Options", xFrameOptionsValue) // scalastyle:off println response.getWriter.println(servletParams.extractFn(result)) // scalastyle:on println @@ -97,8 +108,9 @@ private[spark] object JettyUtils extends Logging { path: String, servletParams: ServletParams[T], securityMgr: SecurityManager, + conf: SparkConf, basePath: String = ""): ServletContextHandler = { - createServletHandler(path, createServlet(servletParams, securityMgr), basePath) + createServletHandler(path, createServlet(servletParams, securityMgr, conf), basePath) } /** Create a context handler that responds to a request with the given path prefix */ diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 61449847add3d..81a121fd441bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -76,9 +76,9 @@ private[spark] abstract class WebUI( def attachPage(page: WebUIPage) { val pagePath = "/" + page.prefix val renderHandler = createServletHandler(pagePath, - (request: HttpServletRequest) => page.render(request), securityManager, basePath) + (request: HttpServletRequest) => page.render(request), securityManager, conf, basePath) val renderJsonHandler = createServletHandler(pagePath.stripSuffix("/") + "/json", - (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath) + (request: HttpServletRequest) => page.renderJson(request), securityManager, conf, basePath) attachHandler(renderHandler) attachHandler(renderJsonHandler) pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]()) From 896edb51ab7a88bbb31259e565311a9be6f2ca6d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 16 Sep 2015 13:20:39 -0700 Subject: [PATCH 312/802] [SPARK-10050] [SPARKR] Support collecting data of MapType in DataFrame. 1. Support collecting data of MapType from DataFrame. 2. Support data of MapType in createDataFrame. Author: Sun Rui Closes #8711 from sun-rui/SPARK-10050. 
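
The change itself is on the SparkR side; as a point of comparison, the analogous round trip in PySpark (where map columns already collect as Python dicts) looks roughly like this — a sketch that assumes an existing `SparkContext` named `sc`:

{% highlight python %}
from pyspark.sql import SQLContext, Row

sqlContext = SQLContext(sc)  # assumes an existing SparkContext `sc`

# A dict-valued column is inferred as MapType, and collect() hands it back
# as a plain Python dict; this patch gives SparkR the equivalent behavior,
# with maps surfacing as R environments.
df = sqlContext.createDataFrame(
    [Row(name="Bob", info={"age": 16.0, "height": 176.5})])
bob = df.collect()[0]
print(bob.info["height"])  # 176.5
{% endhighlight %}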
--- R/pkg/R/SQLContext.R | 5 +- R/pkg/R/deserialize.R | 14 +++++ R/pkg/R/schema.R | 34 ++++++++--- R/pkg/inst/tests/test_sparkSQL.R | 56 +++++++++++++++---- .../scala/org/apache/spark/api/r/SerDe.scala | 31 ++++++++++ .../org/apache/spark/sql/api/r/SQLUtils.scala | 6 ++ 6 files changed, 123 insertions(+), 23 deletions(-) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 4ac057d0f2d83..1c58fd96d750a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -41,10 +41,7 @@ infer_type <- function(x) { if (type == "map") { stopifnot(length(x) > 0) key <- ls(x)[[1]] - list(type = "map", - keyType = "string", - valueType = infer_type(get(key, x)), - valueContainsNull = TRUE) + paste0("map") } else if (type == "array") { stopifnot(length(x) > 0) names <- names(x) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index d1858ec227b56..ce88d0b071b72 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -50,6 +50,7 @@ readTypedObject <- function(con, type) { "t" = readTime(con), "a" = readArray(con), "l" = readList(con), + "e" = readEnv(con), "n" = NULL, "j" = getJobj(readString(con)), stop(paste("Unsupported type for deserialization", type))) @@ -121,6 +122,19 @@ readList <- function(con) { } } +readEnv <- function(con) { + env <- new.env() + len <- readInt(con) + if (len > 0) { + for (i in 1:len) { + key <- readString(con) + value <- readObject(con) + env[[key]] <- value + } + } + env +} + readRaw <- function(con) { dataLen <- readInt(con) readBin(con, raw(), as.integer(dataLen), endian = "big") diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 62d4f73878d29..8df1563f8ebc0 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -131,13 +131,33 @@ checkType <- function(type) { if (type %in% primtiveTypes) { return() } else { - m <- regexec("^array<(.*)>$", type) - matchedStrings <- regmatches(type, m) - if (length(matchedStrings[[1]]) >= 2) { - elemType <- matchedStrings[[1]][2] - checkType(elemType) - return() - } + # Check complex types + firstChar <- substr(type, 1, 1) + switch (firstChar, + a = { + # Array type + m <- regexec("^array<(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + elemType <- matchedStrings[[1]][2] + checkType(elemType) + return() + } + }, + m = { + # Map type + m <- regexec("^map<(.*),(.*)>$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 3) { + keyType <- matchedStrings[[1]][2] + if (keyType != "string" && keyType != "character") { + stop("Key type in a map must be string or character") + } + valueType <- matchedStrings[[1]][3] + checkType(valueType) + return() + } + }) } stop(paste("Unsupported type for Dataframe:", type)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 98d4402d368e1..e159a69584274 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -57,7 +57,7 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) -test_that("infer types", { +test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") expect_equal(infer_type("abc"), "string") @@ -72,9 +72,9 @@ test_that("infer types", { checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) - expect_equal(infer_type(e), - list(type = "map", keyType = "string", valueType = "integer", - valueContainsNull = 
TRUE)) + expect_equal(infer_type(e), "map") + + expect_error(checkType("map"), "Key type in a map must be string or character") }) test_that("structType and structField", { @@ -242,7 +242,7 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) -test_that("create DataFrame with nested array and struct", { +test_that("create DataFrame with nested array and map", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) @@ -253,21 +253,35 @@ test_that("create DataFrame with nested array and struct", { # ldf <- collect(df) # expect_equal(ldf[1,], l[[1]]) + # ArrayType and MapType + e <- new.env() + assign("n", 3L, envir = e) - # ArrayType only for now - l <- list(as.list(1:10), list("a", "b")) - df <- createDataFrame(sqlContext, list(l), c("a", "b")) - expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"))) + l <- list(as.list(1:10), list("a", "b"), e) + df <- createDataFrame(sqlContext, list(l), c("a", "b", "c")) + expect_equal(dtypes(df), list(c("a", "array"), + c("b", "array"), + c("c", "map"))) expect_equal(count(df), 1) ldf <- collect(df) - expect_equal(names(ldf), c("a", "b")) + expect_equal(names(ldf), c("a", "b", "c")) expect_equal(ldf[1, 1][[1]], l[[1]]) expect_equal(ldf[1, 2][[1]], l[[2]]) + e <- ldf$c[[1]] + expect_equal(class(e), "environment") + expect_equal(ls(e), "n") + expect_equal(e$n, 3L) }) +# For test map type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("Collect DataFrame with complex types", { - # only ArrayType now - # TODO: tests for StructType and MapType after they are supported + # ArrayType df <- jsonFile(sqlContext, complexTypeJsonPath) ldf <- collect(df) @@ -277,6 +291,24 @@ test_that("Collect DataFrame with complex types", { expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + + # MapType + schema <- structType(structField("name", "string"), + structField("info", "map")) + df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + expect_equal(dtypes(df), list(c("name", "string"), + c("info", "map"))) + ldf <- collect(df) + expect_equal(nrow(ldf), 3) + expect_equal(ncol(ldf), 2) + expect_equal(names(ldf), c("name", "info")) + expect_equal(ldf$name, c("Bob", "Alice", "David")) + bob <- ldf$info[[1]] + expect_equal(class(bob), "environment") + expect_equal(bob$age, 16) + expect_equal(bob$height, 176.5) + + # TODO: tests for StructType after it is supported }) test_that("jsonFile() on a local file returns a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 3c92bb7a1c73c..0c78613e406e1 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -209,11 +209,23 @@ private[spark] object SerDe { case "array" => dos.writeByte('a') // Array of objects case "list" => dos.writeByte('l') + case "map" => dos.writeByte('e') case "jobj" => dos.writeByte('j') case _ => throw new 
IllegalArgumentException(s"Invalid type $typeStr") } } + private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + if (key == null) { + throw new IllegalArgumentException("Key in map can't be null.") + } else if (!key.isInstanceOf[String]) { + throw new IllegalArgumentException(s"Invalid map key type: ${key.getClass.getName}") + } + + writeString(dos, key.asInstanceOf[String]) + writeObject(dos, value) + } + def writeObject(dos: DataOutputStream, obj: Object): Unit = { if (obj == null) { writeType(dos, "void") @@ -306,6 +318,25 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle map + case v: java.util.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + val iter = v.entrySet.iterator + while(iter.hasNext) { + val entry = iter.next + val key = entry.getKey + val value = entry.getValue + + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case v: scala.collection.Map[_, _] => + writeType(dos, "map") + writeInt(dos, v.size) + v.foreach { case (key, value) => + writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + } + case _ => writeType(dos, "jobj") writeJObj(dos, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d4b834adb6e39..f45d119c8cfdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -64,6 +64,12 @@ private[r] object SQLUtils { case r"\Aarray<(.*)${elemType}>\Z" => { org.apache.spark.sql.types.ArrayType(getSQLDataType(elemType)) } + case r"\Amap<(.*)${keyType},(.*)${valueType}>\Z" => { + if (keyType != "string" && keyType != "character") { + throw new IllegalArgumentException("Key type of a map must be string or character") + } + org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) + } case _ => throw new IllegalArgumentException(s"Invaid type $dataType") } } From d39f15ea2b8bed5342d2f8e3c1936f915c470783 Mon Sep 17 00:00:00 2001 From: Kevin Cox Date: Wed, 16 Sep 2015 15:30:17 -0700 Subject: [PATCH 313/802] [SPARK-9794] [SQL] Fix datetime parsing in SparkSQL. This fixes https://issues.apache.org/jira/browse/SPARK-9794 by using a real ISO8601 parser. (courtesy of the xml component of the standard java library) cc: angelini Author: Kevin Cox Closes #8396 from kevincox/kevincox-sql-time-parsing. 
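
The core of the fix (visible in the diff below) is to hand the string to `javax.xml.bind.DatatypeConverter.parseDateTime`, after first rewriting the non-standard "GMT+01:00"-style suffix into a plain offset. A small Python sketch of just that rewrite step, to make the string manipulation explicit (illustration only, not the Spark code path):

{% highlight python %}
def normalize_iso8601(s):
    # "2000-01-01T00:00GMT+01:00" (ISO 8601 with a "GMT" prefix on the zone
    # offset) is rewritten to the standard "2000-01-01T00:00+01:00" before
    # being handed to a real ISO 8601 parser.
    i = s.find("GMT")
    return s[:i] + s[i + 3:] if i != -1 else s

assert normalize_iso8601("2000-01-01T00:00GMT+01:00") == "2000-01-01T00:00+01:00"
assert normalize_iso8601("2000-12-30T10:00:00Z") == "2000-12-30T10:00:00Z"
{% endhighlight %}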
--- .../sql/catalyst/util/DateTimeUtils.scala | 27 ++++++---------- .../catalyst/util/DateTimeUtilsSuite.scala | 32 +++++++++++++++++++ 2 files changed, 42 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 687ca000d12bb..400c4327be1c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} +import javax.xml.bind.DatatypeConverter; import org.apache.spark.unsafe.types.UTF8String @@ -109,30 +110,22 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - if (!s.contains('T')) { + var indexOfGMT = s.indexOf("GMT"); + if (indexOfGMT != -1) { + // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) + val s0 = s.substring(0, indexOfGMT) + val s1 = s.substring(indexOfGMT + 3) + // Mapped to 2000-01-01T00:00+01:00 + stringToTime(s0 + s1) + } else if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { Timestamp.valueOf(s) } else { Date.valueOf(s) } - } else if (s.endsWith("Z")) { - // this is zero timezone of ISO8601 - stringToTime(s.substring(0, s.length - 1) + "GMT-00:00") - } else if (s.indexOf("GMT") == -1) { - // timezone with ISO8601 - val inset = "+00.00".length - val s0 = s.substring(0, s.length - inset) - val s1 = s.substring(s.length - inset, s.length) - if (s0.substring(s0.lastIndexOf(':')).contains('.')) { - stringToTime(s0 + "GMT" + s1) - } else { - stringToTime(s0 + ".0GMT" + s1) - } } else { - // ISO8601 with GMT insert - val ISO8601GMT: SimpleDateFormat = new SimpleDateFormat( "yyyy-MM-dd'T'HH:mm:ss.SSSz" ) - ISO8601GMT.parse(s) + DatatypeConverter.parseDateTime(s).getTime() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 6b9a11f0ff743..46335941b62d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -136,6 +136,38 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } + test("string to time") { + // Tests with UTC. + var c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-00:00") === c.getTime()) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30T10:00:00Z") === c.getTime()) + + // Tests with set time zone. + c.setTimeZone(TimeZone.getTimeZone("GMT-04:00")) + c.set(Calendar.MILLISECOND, 0) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00-04:00") === c.getTime()) + + c.set(1900, 0, 1, 0, 0, 0) + assert(stringToTime("1900-01-01T00:00:00GMT-04:00") === c.getTime()) + + // Tests with local time zone. 
+ c.setTimeZone(TimeZone.getDefault()) + c.set(Calendar.MILLISECOND, 0) + + c.set(2000, 11, 30, 0, 0, 0) + assert(stringToTime("2000-12-30") === new Date(c.getTimeInMillis())) + + c.set(2000, 11, 30, 10, 0, 0) + assert(stringToTime("2000-12-30 10:00:00") === new Timestamp(c.getTimeInMillis())) + } + test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) From 49c649fa0b6affed108dbae85373b4b7247b338c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 16 Sep 2015 15:32:01 -0700 Subject: [PATCH 314/802] Tiny style fix for d39f15ea2b8bed5342d2f8e3c1936f915c470783. --- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 400c4327be1c7..781ed1688a327 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import java.util.{TimeZone, Calendar} -import javax.xml.bind.DatatypeConverter; +import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String From 69c9830d288d5b8d7f0abe7c8a65a4c966580a49 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 17 Sep 2015 00:48:57 -0700 Subject: [PATCH 315/802] [MINOR] [CORE] Fixes minor variable name typo Author: Cheng Lian Closes #8784 from liancheng/typo-fix. --- .../apache/spark/serializer/GenericAvroSerializerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index bc9f3708ed69d..87f25e7245e1f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -76,9 +76,9 @@ class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { test("caches previously seen schemas") { val genericSer = new GenericAvroSerializer(conf.getAvroSchema) val compressedSchema = genericSer.compress(schema) - val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) assert(compressedSchema.eq(genericSer.compress(schema))) - assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) } } From c633ed3260140f1288f326acc4d7a10dcd2e27d5 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:43:59 -0700 Subject: [PATCH 316/802] [SPARK-10284] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.tuning Author: Yu ISHIKAWA Closes #8694 from yu-iskw/SPARK-10284. 
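
These `@since` annotations rely on the `since` decorator imported from the top-level `pyspark` package. For background, a simplified sketch of what such a decorator does (an approximation, not the exact Spark implementation): it appends a Sphinx `.. versionadded::` directive to the wrapped function's docstring.

{% highlight python %}
import re


def since(version):
    # Approximate sketch: append ".. versionadded:: <version>" to the
    # docstring, matching the indentation already used in that docstring.
    indent_pattern = re.compile(r'\n( +)')

    def deco(f):
        doc = f.__doc__ or ""
        indents = indent_pattern.findall(doc)
        indent = ' ' * (min(len(i) for i in indents) if indents else 0)
        f.__doc__ = doc.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f
    return deco


@since("1.4.0")
def getNumFolds():
    """Gets the value of numFolds or its default value."""
    return 3


print(getNumFolds.__doc__)
# Gets the value of numFolds or its default value.
#
# .. versionadded:: 1.4.0
{% endhighlight %}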
--- python/pyspark/ml/tuning.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index cae778869e9c5..ab5621f45c72c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,6 +18,7 @@ import itertools import numpy as np +from pyspark import since from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model from pyspark.ml.util import keyword_only @@ -47,11 +48,14 @@ class ParamGridBuilder(object): True >>> all([m in expected for m in output]) True + + .. versionadded:: 1.4.0 """ def __init__(self): self._param_grid = {} + @since("1.4.0") def addGrid(self, param, values): """ Sets the given parameters in this grid to fixed values. @@ -60,6 +64,7 @@ def addGrid(self, param, values): return self + @since("1.4.0") def baseOn(self, *args): """ Sets the given parameters in this grid to fixed values. @@ -73,6 +78,7 @@ def baseOn(self, *args): return self + @since("1.4.0") def build(self): """ Builds and returns all combinations of parameters specified @@ -104,6 +110,8 @@ class CrossValidator(Estimator): >>> cvModel = cv.fit(dataset) >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -142,6 +150,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF self._set(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): """ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): @@ -150,6 +159,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.4.0") def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. @@ -157,12 +167,14 @@ def setEstimator(self, value): self._paramMap[self.estimator] = value return self + @since("1.4.0") def getEstimator(self): """ Gets the value of estimator or its default value. """ return self.getOrDefault(self.estimator) + @since("1.4.0") def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. @@ -170,12 +182,14 @@ def setEstimatorParamMaps(self, value): self._paramMap[self.estimatorParamMaps] = value return self + @since("1.4.0") def getEstimatorParamMaps(self): """ Gets the value of estimatorParamMaps or its default value. """ return self.getOrDefault(self.estimatorParamMaps) + @since("1.4.0") def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. @@ -183,12 +197,14 @@ def setEvaluator(self, value): self._paramMap[self.evaluator] = value return self + @since("1.4.0") def getEvaluator(self): """ Gets the value of evaluator or its default value. """ return self.getOrDefault(self.evaluator) + @since("1.4.0") def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. @@ -196,6 +212,7 @@ def setNumFolds(self, value): self._paramMap[self.numFolds] = value return self + @since("1.4.0") def getNumFolds(self): """ Gets the value of numFolds or its default value. @@ -231,7 +248,15 @@ def _fit(self, dataset): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + @since("1.4.0") def copy(self, extra=None): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. 
This copies creates a deep copy of + the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ if extra is None: extra = dict() newCV = Params.copy(self, extra) @@ -246,6 +271,8 @@ def copy(self, extra=None): class CrossValidatorModel(Model): """ Model from k-fold cross validation. + + .. versionadded:: 1.4.0 """ def __init__(self, bestModel): @@ -256,6 +283,7 @@ def __init__(self, bestModel): def _transform(self, dataset): return self.bestModel.transform(dataset) + @since("1.4.0") def copy(self, extra=None): """ Creates a copy of this instance with a randomly generated uid From 29bf8aa5a51fdd8c2600533297f991e14fa27c03 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:45:20 -0700 Subject: [PATCH 317/802] [SPARK-10283] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.regression Author: Yu ISHIKAWA Closes #8693 from yu-iskw/SPARK-10283. --- python/pyspark/ml/regression.py | 65 +++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a9503608b7f25..21d454f9003bb 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -62,6 +63,8 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.4.0 """ @keyword_only @@ -81,6 +84,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, standardization=True): @@ -96,13 +100,31 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearRegressionModel(java_model) + @since("1.4.0") + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + @since("1.4.0") + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + class LinearRegressionModel(JavaModel): """ Model fitted by LinearRegression. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -110,6 +132,7 @@ def weights(self): return self._call_java("weights") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -162,6 +185,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. 
versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -193,6 +218,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -209,6 +235,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return DecisionTreeRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -216,6 +243,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -225,13 +253,19 @@ def getImpurity(self): @inherit_doc class DecisionTreeModel(JavaModel): + """Abstraction for Decision Tree models. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def numNodes(self): """Return number of nodes of the decision tree.""" return self._call_java("numNodes") @property + @since("1.5.0") def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") @@ -242,8 +276,13 @@ def __repr__(self): @inherit_doc class TreeEnsembleModels(JavaModel): + """Represents a tree ensemble model. + + .. versionadded:: 1.5.0 + """ @property + @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) @@ -256,6 +295,8 @@ def __repr__(self): class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. + + .. versionadded:: 1.4.0 """ @@ -282,6 +323,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -336,6 +379,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, @@ -353,6 +397,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return RandomForestRegressionModel(java_model) + @since("1.4.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -360,12 +405,14 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.4.0") def getImpurity(self): """ Gets the value of impurity or its default value. """ return self.getOrDefault(self.impurity) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -373,12 +420,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. 
@@ -386,12 +435,14 @@ def setNumTrees(self, value): self._paramMap[self.numTrees] = value return self + @since("1.4.0") def getNumTrees(self): """ Gets the value of numTrees or its default value. """ return self.getOrDefault(self.numTrees) + @since("1.4.0") def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. @@ -399,6 +450,7 @@ def setFeatureSubsetStrategy(self, value): self._paramMap[self.featureSubsetStrategy] = value return self + @since("1.4.0") def getFeatureSubsetStrategy(self): """ Gets the value of featureSubsetStrategy or its default value. @@ -409,6 +461,8 @@ def getFeatureSubsetStrategy(self): class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. + + .. versionadded:: 1.4.0 """ @@ -435,6 +489,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -481,6 +537,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -498,6 +555,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTRegressionModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -505,12 +563,14 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. """ return self.getOrDefault(self.lossType) + @since("1.4.0") def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. @@ -518,12 +578,14 @@ def setSubsamplingRate(self, value): self._paramMap[self.subsamplingRate] = value return self + @since("1.4.0") def getSubsamplingRate(self): """ Gets the value of subsamplingRate or its default value. """ return self.getOrDefault(self.subsamplingRate) + @since("1.4.0") def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. @@ -531,6 +593,7 @@ def setStepSize(self, value): self._paramMap[self.stepSize] = value return self + @since("1.4.0") def getStepSize(self): """ Gets the value of stepSize or its default value. @@ -541,6 +604,8 @@ def getStepSize(self): class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. + + .. versionadded:: 1.4.0 """ From 0ded87a4d49d4484e202bd2ec781821b57b5882c Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:47:21 -0700 Subject: [PATCH 318/802] [SPARK-10281] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.clustering Author: Yu ISHIKAWA Closes #8691 from yu-iskw/SPARK-10281. --- python/pyspark/ml/clustering.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index cb4c16e25a7a3..7bb8ab94e17df 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -26,8 +27,11 @@ class KMeansModel(JavaModel): """ Model fitted by KMeans. + + .. versionadded:: 1.5.0 """ + @since("1.5.0") def clusterCenters(self): """Get the cluster centers, represented as a list of NumPy arrays.""" return [c.toArray() for c in self._call_java("clusterCenters")] @@ -55,6 +59,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -88,6 +94,7 @@ def _create_model(self, java_model): return KMeansModel(java_model) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", predictionCol="prediction", k=2, initMode="k-means||", initSteps=5, tol=1e-4, maxIter=20, seed=None): """ @@ -99,6 +106,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("1.5.0") def setK(self, value): """ Sets the value of :py:attr:`k`. @@ -110,12 +118,14 @@ def setK(self, value): self._paramMap[self.k] = value return self + @since("1.5.0") def getK(self): """ Gets the value of `k` """ return self.getOrDefault(self.k) + @since("1.5.0") def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. @@ -130,12 +140,14 @@ def setInitMode(self, value): self._paramMap[self.initMode] = value return self + @since("1.5.0") def getInitMode(self): """ Gets the value of `initMode` """ return self.getOrDefault(self.initMode) + @since("1.5.0") def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. @@ -147,6 +159,7 @@ def setInitSteps(self, value): self._paramMap[self.initSteps] = value return self + @since("1.5.0") def getInitSteps(self): """ Gets the value of `initSteps` From 39b44cb52eb225469eb4ccdf696f0bc6405b9184 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:48:45 -0700 Subject: [PATCH 319/802] [SPARK-10278] [MLLIB] [PYSPARK] Add @since annotation to pyspark.mllib.tree Author: Yu ISHIKAWA Closes #8685 from yu-iskw/SPARK-10278. --- python/pyspark/mllib/tree.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 372b86a7c95d9..0001b60093a69 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -19,7 +19,7 @@ import random -from pyspark import SparkContext, RDD +from pyspark import SparkContext, RDD, since from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -30,6 +30,11 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): + """TreeEnsembleModel + + .. versionadded:: 1.3.0 + """ + @since("1.3.0") def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -45,12 +50,14 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.3.0") def numTrees(self): """ Get number of trees in ensemble. 
""" return self.call("numTrees") + @since("1.3.0") def totalNumNodes(self): """ Get total number of nodes, summed over all trees in the @@ -62,6 +69,7 @@ def __repr__(self): """ Summary of model """ return self._java_model.toString() + @since("1.3.0") def toDebugString(self): """ Full model """ return self._java_model.toDebugString() @@ -72,7 +80,10 @@ class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): .. note:: Experimental A decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ + @since("1.1.0") def predict(self, x): """ Predict the label of one or more examples. @@ -90,16 +101,23 @@ def predict(self, x): else: return self.call("predict", _convert_to_vector(x)) + @since("1.1.0") def numNodes(self): + """Get number of nodes in tree, including leaf nodes.""" return self._java_model.numNodes() + @since("1.1.0") def depth(self): + """Get depth of tree. + E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + """ return self._java_model.depth() def __repr__(self): """ summary of model. """ return self._java_model.toString() + @since("1.2.0") def toDebugString(self): """ full model. """ return self._java_model.toDebugString() @@ -115,6 +133,8 @@ class DecisionTree(object): Learning algorithm for a decision tree model for classification or regression. + + .. versionadded:: 1.1.0 """ @classmethod @@ -127,6 +147,7 @@ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, m return DecisionTreeModel(model) @classmethod + @since("1.1.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -185,6 +206,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @classmethod + @since("1.1.0") def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): @@ -239,6 +261,8 @@ class RandomForestModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a random forest model. + + .. versionadded:: 1.2.0 """ @classmethod @@ -252,6 +276,8 @@ class RandomForest(object): Learning algorithm for a random forest model for classification or regression. + + .. versionadded:: 1.2.0 """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -271,6 +297,7 @@ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, return RandomForestModel(model) @classmethod + @since("1.2.0") def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, seed=None): @@ -352,6 +379,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, maxDepth, maxBins, seed) @classmethod + @since("1.2.0") def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", impurity="variance", maxDepth=4, maxBins=32, seed=None): """ @@ -418,6 +446,8 @@ class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): .. note:: Experimental Represents a gradient-boosted tree model. + + .. versionadded:: 1.3.0 """ @classmethod @@ -431,6 +461,8 @@ class GradientBoostedTrees(object): Learning algorithm for a gradient boosted trees model for classification or regression. + + .. 
versionadded:: 1.3.0 """ @classmethod @@ -443,6 +475,7 @@ def _train(cls, data, algo, categoricalFeaturesInfo, return GradientBoostedTreesModel(model) @classmethod + @since("1.3.0") def trainClassifier(cls, data, categoricalFeaturesInfo, loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): @@ -505,6 +538,7 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, loss, numIterations, learningRate, maxDepth, maxBins) @classmethod + @since("1.3.0") def trainRegressor(cls, data, categoricalFeaturesInfo, loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32): From 4a0b56e8dbb3713b16e58738201d838ffc4b258b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:00 -0700 Subject: [PATCH 320/802] [SPARK-10279] [MLLIB] [PYSPARK] [DOCS] Add @since annotation to pyspark.mllib.util Author: Yu ISHIKAWA Closes #8689 from yu-iskw/SPARK-10279. --- python/pyspark/mllib/util.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 10a1e4b3eb0fc..39bc6586dd582 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -23,7 +23,7 @@ xrange = range basestring = str -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -32,6 +32,8 @@ class MLUtils(object): """ Helper methods to load, save and pre-process data used in MLlib. + + .. versionadded:: 1.0.0 """ @staticmethod @@ -69,6 +71,7 @@ def _convert_labeled_point_to_libsvm(p): return " ".join(items) @staticmethod + @since("1.0.0") def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None): """ Loads labeled data in the LIBSVM format into an RDD of @@ -123,6 +126,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod + @since("1.0.0") def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. @@ -147,6 +151,7 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) @staticmethod + @since("1.1.0") def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. @@ -172,6 +177,7 @@ def loadLabeledPoints(sc, path, minPartitions=None): return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) @staticmethod + @since("1.5.0") def appendBias(data): """ Returns a new vector with `1.0` (bias) appended to @@ -186,6 +192,7 @@ def appendBias(data): return _convert_to_vector(np.append(vec.toArray(), 1.0)) @staticmethod + @since("1.5.0") def loadVectors(sc, path): """ Loads vectors saved using `RDD[Vector].saveAsTextFile` @@ -197,6 +204,8 @@ def loadVectors(sc, path): class Saveable(object): """ Mixin for models and transformers which may be saved as files. + + .. versionadded:: 1.3.0 """ def save(self, sc, path): @@ -222,9 +231,13 @@ class JavaSaveable(Saveable): """ Mixin for models that provide save() through their Scala implementation. + + .. 
versionadded:: 1.3.0 """ + @since("1.3.0") def save(self, sc, path): + """Save this model to the given path.""" if not isinstance(sc, SparkContext): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): @@ -235,6 +248,8 @@ def save(self, sc, path): class Loader(object): """ Mixin for classes which can load saved models from files. + + .. versionadded:: 1.3.0 """ @classmethod @@ -256,6 +271,8 @@ class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala implementation. + + .. versionadded:: 1.3.0 """ @classmethod @@ -280,15 +297,21 @@ def _load_java(cls, sc, path): return java_obj.load(sc._jsc.sc(), path) @classmethod + @since("1.3.0") def load(cls, sc, path): + """Load a model from the given path.""" java_model = cls._load_java(sc, path) return cls(java_model) class LinearDataGenerator(object): - """Utils for generating linear data""" + """Utils for generating linear data. + + .. versionadded:: 1.5.0 + """ @staticmethod + @since("1.5.0") def generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps): """ @@ -311,6 +334,7 @@ def generateLinearInput(intercept, weights, xMean, xVariance, xVariance, int(nPoints), int(seed), float(eps))) @staticmethod + @since("1.5.0") def generateLinearRDD(sc, nexamples, nfeatures, eps, nParts=2, intercept=0.0): """ From c74d38fd8faf8cba981cf934341d24b9a3167025 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:50:46 -0700 Subject: [PATCH 321/802] [SPARK-10274] [MLLIB] Add @since annotation to pyspark.mllib.fpm Author: Yu ISHIKAWA Closes #8665 from yu-iskw/SPARK-10274. --- python/pyspark/mllib/fpm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b18..bdabba9602a8c 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -19,7 +19,7 @@ from numpy import array from collections import namedtuple -from pyspark import SparkContext +from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -41,8 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + + .. versionadded:: 1.4.0 """ + @since("1.4.0") def freqItemsets(self): """ Returns the frequent itemsets of this model. @@ -55,9 +58,12 @@ class FPGrowth(object): .. note:: Experimental A Parallel FP-growth algorithm to mine frequent itemsets. + + .. versionadded:: 1.4.0 """ @classmethod + @since("1.4.0") def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. @@ -74,6 +80,8 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): """ Represents an (items, freq) tuple. + + .. versionadded:: 1.4.0 """ From 268088b899e6e165e746aed87840d47bfaf50c43 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 17 Sep 2015 08:51:19 -0700 Subject: [PATCH 322/802] [SPARK-10282] [ML] [PYSPARK] [DOCS] Add @since annotation to pyspark.ml.recommendation Author: Yu ISHIKAWA Closes #8692 from yu-iskw/SPARK-10282. 
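For context on what these annotations do: the `since` decorator imported from `pyspark` appends a Sphinx `.. versionadded::` directive to the decorated callable's docstring, which is how the version notes appear in the generated API docs without touching runtime behaviour. A minimal sketch of such a decorator (simplified, not the exact shipped implementation) could look like:

```python
import re


def since(version):
    """Sketch of a decorator that records the version a public API was added in."""
    indent_pattern = re.compile(r'\n( +)')

    def decorator(f):
        # Reuse the docstring's existing indentation so Sphinx parses the directive correctly.
        doc = f.__doc__ or ""
        indents = indent_pattern.findall(doc)
        indent = ' ' * (min(len(i) for i in indents) if indents else 0)
        f.__doc__ = doc.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version)
        return f
    return decorator
```

Applied as `@since("1.4.0")` above a method, it only augments the documentation; the decorated function is returned unchanged.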
--- python/pyspark/ml/recommendation.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index b06099ac0aee6..ec5748a1cfe94 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark import since from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * @@ -80,6 +81,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=3.19...) >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -122,6 +125,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): @@ -137,6 +141,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem def _create_model(self, java_model): return ALSModel(java_model) + @since("1.4.0") def setRank(self, value): """ Sets the value of :py:attr:`rank`. @@ -144,12 +149,14 @@ def setRank(self, value): self._paramMap[self.rank] = value return self + @since("1.4.0") def getRank(self): """ Gets the value of rank or its default value. """ return self.getOrDefault(self.rank) + @since("1.4.0") def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. @@ -157,12 +164,14 @@ def setNumUserBlocks(self, value): self._paramMap[self.numUserBlocks] = value return self + @since("1.4.0") def getNumUserBlocks(self): """ Gets the value of numUserBlocks or its default value. """ return self.getOrDefault(self.numUserBlocks) + @since("1.4.0") def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. @@ -170,12 +179,14 @@ def setNumItemBlocks(self, value): self._paramMap[self.numItemBlocks] = value return self + @since("1.4.0") def getNumItemBlocks(self): """ Gets the value of numItemBlocks or its default value. """ return self.getOrDefault(self.numItemBlocks) + @since("1.4.0") def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. @@ -183,6 +194,7 @@ def setNumBlocks(self, value): self._paramMap[self.numUserBlocks] = value self._paramMap[self.numItemBlocks] = value + @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. @@ -190,12 +202,14 @@ def setImplicitPrefs(self, value): self._paramMap[self.implicitPrefs] = value return self + @since("1.4.0") def getImplicitPrefs(self): """ Gets the value of implicitPrefs or its default value. """ return self.getOrDefault(self.implicitPrefs) + @since("1.4.0") def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. @@ -203,12 +217,14 @@ def setAlpha(self, value): self._paramMap[self.alpha] = value return self + @since("1.4.0") def getAlpha(self): """ Gets the value of alpha or its default value. """ return self.getOrDefault(self.alpha) + @since("1.4.0") def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. 
@@ -216,12 +232,14 @@ def setUserCol(self, value): self._paramMap[self.userCol] = value return self + @since("1.4.0") def getUserCol(self): """ Gets the value of userCol or its default value. """ return self.getOrDefault(self.userCol) + @since("1.4.0") def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. @@ -229,12 +247,14 @@ def setItemCol(self, value): self._paramMap[self.itemCol] = value return self + @since("1.4.0") def getItemCol(self): """ Gets the value of itemCol or its default value. """ return self.getOrDefault(self.itemCol) + @since("1.4.0") def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. @@ -242,12 +262,14 @@ def setRatingCol(self, value): self._paramMap[self.ratingCol] = value return self + @since("1.4.0") def getRatingCol(self): """ Gets the value of ratingCol or its default value. """ return self.getOrDefault(self.ratingCol) + @since("1.4.0") def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. @@ -255,6 +277,7 @@ def setNonnegative(self, value): self._paramMap[self.nonnegative] = value return self + @since("1.4.0") def getNonnegative(self): """ Gets the value of nonnegative or its default value. @@ -265,14 +288,18 @@ def getNonnegative(self): class ALSModel(JavaModel): """ Model fitted by ALS. + + .. versionadded:: 1.4.0 """ @property + @since("1.4.0") def rank(self): """rank of the matrix factorization model""" return self._call_java("rank") @property + @since("1.4.0") def userFactors(self): """ a DataFrame that stores user factors in two columns: `id` and @@ -281,6 +308,7 @@ def userFactors(self): return self._call_java("userFactors") @property + @since("1.4.0") def itemFactors(self): """ a DataFrame that stores item factors in two columns: `id` and From e51345e1e04e439827a07c95887d14ba38333057 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 17 Sep 2015 09:17:43 -0700 Subject: [PATCH 323/802] [SPARK-10077] [DOCS] [ML] Add package info for java of ml/feature Should be the same as SPARK-7808 but use Java for the code example. It would be great to add package doc for `spark.ml.feature`. Author: Holden Karau Closes #8740 from holdenk/SPARK-10077-JAVA-PACKAGE-DOC-FOR-SPARK.ML.FEATURE. --- .../apache/spark/ml/feature/package-info.java | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java new file mode 100644 index 0000000000000..c22d2e0cd2d90 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package-info.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +/** + * Feature transformers + * + * The `ml.feature` package provides common feature transformers that help convert raw data or + * features into more suitable forms for model fitting. + * Most feature transformers are implemented as {@link org.apache.spark.ml.Transformer}s, which + * transforms one {@link org.apache.spark.sql.DataFrame} into another, e.g., + * {@link org.apache.spark.feature.HashingTF}. + * Some feature transformers are implemented as {@link org.apache.spark.ml.Estimator}}s, because the + * transformation requires some aggregated information of the dataset, e.g., document + * frequencies in {@link org.apache.spark.ml.feature.IDF}. + * For those feature transformers, calling {@link org.apache.spark.ml.Estimator#fit} is required to + * obtain the model first, e.g., {@link org.apache.spark.ml.feature.IDFModel}, in order to apply + * transformation. + * The transformation is usually done by appending new columns to the input + * {@link org.apache.spark.sql.DataFrame}, so all input columns are carried over. + * + * We try to make each transformer minimal, so it becomes flexible to assemble feature + * transformation pipelines. + * {@link org.apache.spark.ml.Pipeline} can be used to chain feature transformers, and + * {@link org.apache.spark.ml.feature.VectorAssembler} can be used to combine multiple feature + * transformations, for example: + * + *
+ * 
+ *   import java.util.Arrays;
+ *
+ *   import org.apache.spark.api.java.JavaRDD;
+ *   import static org.apache.spark.sql.types.DataTypes.*;
+ *   import org.apache.spark.sql.types.StructType;
+ *   import org.apache.spark.sql.DataFrame;
+ *   import org.apache.spark.sql.RowFactory;
+ *   import org.apache.spark.sql.Row;
+ *
+ *   import org.apache.spark.ml.feature.*;
+ *   import org.apache.spark.ml.Pipeline;
+ *   import org.apache.spark.ml.PipelineStage;
+ *   import org.apache.spark.ml.PipelineModel;
+ *
+ *  // a DataFrame with three columns: id (integer), text (string), and rating (double).
+ *  StructType schema = createStructType(
+ *    Arrays.asList(
+ *      createStructField("id", IntegerType, false),
+ *      createStructField("text", StringType, false),
+ *      createStructField("rating", DoubleType, false)));
+ *  JavaRDD<Row> rowRDD = jsc.parallelize(
+ *    Arrays.asList(
+ *      RowFactory.create(0, "Hi I heard about Spark", 3.0),
+ *      RowFactory.create(1, "I wish Java could use case classes", 4.0),
+ *      RowFactory.create(2, "Logistic regression models are neat", 4.0)));
+ *  DataFrame df = jsql.createDataFrame(rowRDD, schema);
+ *  // define feature transformers
+ *  RegexTokenizer tok = new RegexTokenizer()
+ *    .setInputCol("text")
+ *    .setOutputCol("words");
+ *  StopWordsRemover sw = new StopWordsRemover()
+ *    .setInputCol("words")
+ *    .setOutputCol("filtered_words");
+ *  HashingTF tf = new HashingTF()
+ *    .setInputCol("filtered_words")
+ *    .setOutputCol("tf")
+ *    .setNumFeatures(10000);
+ *  IDF idf = new IDF()
+ *    .setInputCol("tf")
+ *    .setOutputCol("tf_idf");
+ *  VectorAssembler assembler = new VectorAssembler()
+ *    .setInputCols(new String[] {"tf_idf", "rating"})
+ *    .setOutputCol("features");
+ *
+ *  // assemble and fit the feature transformation pipeline
+ *  Pipeline pipeline = new Pipeline()
+ *    .setStages(new PipelineStage[] {tok, sw, tf, idf, assembler});
+ *  PipelineModel model = pipeline.fit(df);
+ *
+ *  // save transformed features with raw data
+ *  model.transform(df)
+ *    .select("id", "text", "rating", "features")
+ *    .write().format("parquet").save("/output/path");
+ * 
+ * 
+ * + * Some feature transformers implemented in MLlib are inspired by those implemented in scikit-learn. + * The major difference is that most scikit-learn feature transformers operate eagerly on the entire + * input dataset, while MLlib's feature transformers operate lazily on individual columns, + * which is more efficient and flexible to handle large and complex datasets. + * + * @see + * scikit-learn.preprocessing + */ +package org.apache.spark.ml.feature; From 2a508df20d03b3d4a3c05b65fb02d849bc080ef9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 09:21:21 -0700 Subject: [PATCH 324/802] [SPARK-10459] [SQL] Do not need to have ConvertToSafe for PythonUDF JIRA: https://issues.apache.org/jira/browse/SPARK-10459 As mentioned in the JIRA, `PythonUDF` actually could process `UnsafeRow`. Specially, the rows in `childResults` in `BatchPythonEvaluation` will be projected to a `MutableRow`. So I think we can enable `canProcessUnsafeRows` for `BatchPythonEvaluation` and get rid of redundant `ConvertToSafe`. Author: Liang-Chi Hsieh Closes #8616 from viirya/pyudf-unsafe. --- .../scala/org/apache/spark/sql/execution/pythonUDFs.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 5a58d846ad80b..d0411da6fdf5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -337,6 +337,10 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + override def outputsUnsafeRows: Boolean = false + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { val childResults = child.execute().map(_.copy()) From c88bb5df94f9696677c3a429472114bc66f32a52 Mon Sep 17 00:00:00 2001 From: "yangping.wu" Date: Thu, 17 Sep 2015 09:52:40 -0700 Subject: [PATCH 325/802] [SPARK-10660] Doc describe error in the "Running Spark on YARN" page MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the Configuration section, the **spark.yarn.driver.memoryOverhead** and **spark.yarn.am.memoryOverhead**‘s default value should be "driverMemory * 0.10, with minimum of 384" and "AM memory * 0.10, with minimum of 384" respectively. Because from Spark 1.4.0, the **MEMORY_OVERHEAD_FACTOR** is set to 0.1.0, not 0.07. Author: yangping.wu Closes #8797 from 397090770/SparkOnYarnDocError. --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d1244323edfff..3a961d245f3de 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,14 +211,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.driver.memoryOverhead - driverMemory * 0.07, with minimum of 384 + driverMemory * 0.10, with minimum of 384 The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). 
spark.yarn.am.memoryOverhead - AM memory * 0.07, with minimum of 384 + AM memory * 0.10, with minimum of 384 Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. From 136c77d8bbf48f7c45dd7c3fbe261a0476f455fe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 17 Sep 2015 10:02:15 -0700 Subject: [PATCH 326/802] [SPARK-10642] [PYSPARK] Fix crash when calling rdd.lookup() on tuple keys JIRA: https://issues.apache.org/jira/browse/SPARK-10642 When calling `rdd.lookup()` on a RDD with tuple keys, `portable_hash` will return a long. That causes `DAGScheduler.submitJob` to throw `java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Integer`. Author: Liang-Chi Hsieh Closes #8796 from viirya/fix-pyrdd-lookup. --- python/pyspark/rdd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9ef60a7e2c84b..ab5aab1e115f7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -84,7 +84,7 @@ def portable_hash(x): h ^= len(x) if h == -1: h = -2 - return h + return int(h) return hash(x) @@ -2192,6 +2192,9 @@ def lookup(self, key): [42] >>> sorted.lookup(1024) [] + >>> rdd2 = sc.parallelize([(('a', 'b'), 'c')]).groupByKey() + >>> list(rdd2.lookup(('a', 'b'))[0]) + ['c'] """ values = self.filter(lambda kv: kv[0] == key).values() From 81b4db374dd61b6f1c30511c70b6ab2a52c68faa Mon Sep 17 00:00:00 2001 From: Josiah Samuel Date: Thu, 17 Sep 2015 10:18:21 -0700 Subject: [PATCH 327/802] [SPARK-10172] [CORE] disable sort in HistoryServer webUI This pull request is to address the JIRA SPARK-10172 (History Server web UI gets messed up when sorting on any column). The content of the table gets messed up due to the rowspan attribute of the table data(cell) during sorting. The current table sort library used in SparkUI (sorttable.js) doesn't support/handle cells(td) with rowspans. The fix will disable the table sort in the web UI, when there are jobs listed with multiple attempts. Author: Josiah Samuel Closes #8506 from josiahsams/SPARK-10172. --- .../scala/org/apache/spark/deploy/history/HistoryPage.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 0830cc1ba1245..b347cb3be69f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -51,7 +51,10 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val hasMultipleAttempts = appsToShow.exists(_.attempts.size > 1) val appTable = if (hasMultipleAttempts) { - UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, appsToShow) + // Sorting is disable here as table sort on rowspan has issues. + // ref. SPARK-10172 + UIUtils.listingTable(appWithAttemptHeader, appWithAttemptRow, + appsToShow, sortable = false) } else { UIUtils.listingTable(appHeader, appRow, appsToShow) } From 36d8b278d82e788bf583e8438fac524d0023311d Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 17 Sep 2015 10:25:18 -0700 Subject: [PATCH 328/802] [SPARK-10531] [CORE] AppId is set as AppName in status rest api Verify it manually. Author: Jeff Zhang Closes #8688 from zjffdu/SPARK-10531. 
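A side note on the SPARK-10642 fix above (`return int(h)` in `portable_hash`): on 64-bit Python 2 builds, the tuple-hash arithmetic can promote the running value to `long`, and it stays a `long` even after masking, so the partition index derived from it crosses Py4J as a `java.lang.Long` where `DAGScheduler.submitJob` expects an `Integer`. A rough sketch of the promotion (the arithmetic loosely mirrors `portable_hash`; values are illustrative):

```python
# Python 2, 64-bit build
import sys

h = 0x345678
for item in ('a', 'b'):
    h ^= hash(item)
    h *= 1000003        # can overflow the C long range -> promoted to Python long
    h &= sys.maxsize    # value is back in range, but the *type* remains long
h ^= 2                  # len(('a', 'b'))

print(type(h))          # <type 'long'> -- Py4J would send this as java.lang.Long
print(type(int(h)))     # <type 'int'>  -- the fix, so the JVM receives an Integer
```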
--- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../spark/deploy/history/FsHistoryProvider.scala | 9 ++++----- .../scala/org/apache/spark/deploy/master/Master.scala | 2 +- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 11 ++++++----- .../scala/org/apache/spark/ui/UISeleniumSuite.scala | 2 +- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a2f34eafa2c38..9c3218719f7fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -521,6 +521,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _applicationId = _taskScheduler.applicationId() _applicationAttemptId = taskScheduler.applicationAttemptId() _conf.set("spark.app.id", _applicationId) + _ui.foreach(_.setAppId(_applicationId)) _env.blockManager.initialize(_applicationId) // The metrics system for Driver need to be set spark.app.id to app ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a5755eac36396..8eb2ba1e8683b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -146,16 +146,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } val appListener = new ApplicationEventListener() replayBus.addListener(appListener) - val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.map { info => - ui.setAppName(s"${info.name} ($appId)") - + val appAttemptInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), + replayBus) + appAttemptInfo.map { info => val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) // make sure to set admin acls before view acls so they are properly picked up diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 26904d39a9bec..d518e92133aad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -944,7 +944,7 @@ private[deploy] class Master( val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) val replayBus = new ReplayListenerBus() val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName + status, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) + appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) val maybeTruncated = eventLogFile.endsWith(EventLoggingListener.IN_PROGRESS) try { replayBus.replay(logInput, eventLogFile, maybeTruncated) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index d8b90568b7b9a..99085ada9f0af 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -56,6 +56,8 @@ private[spark] class SparkUI private ( val stagesTab = new StagesTab(this) + var appId: String = _ + /** Initialize all components of the server. */ def initialize() { attachTab(new JobsTab(this)) @@ -75,9 +77,8 @@ private[spark] class SparkUI private ( def getAppName: String = appName - /** Set the app name for this UI. */ - def setAppName(name: String) { - appName = name + def setAppId(id: String): Unit = { + appId = id } /** Stop the server behind this web interface. Only valid after bind(). */ @@ -94,12 +95,12 @@ private[spark] class SparkUI private ( private[spark] def appUIAddress = s"http://$appUIHostPort" def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == appName) Some(this) else None + if (appId == this.appId) Some(this) else None } def getApplicationInfoList: Iterator[ApplicationInfo] = { Iterator(new ApplicationInfo( - id = appName, + id = appId, name = appName, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 22e30ecaf0533..18eec7da9763e 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -658,6 +658,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/test/" + path) + new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } From e0dc2bc232206d2f4da4278502c1f88babc8b55a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 17 Sep 2015 11:05:30 -0700 Subject: [PATCH 329/802] [SPARK-10650] Clean before building docs The [published docs for 1.5.0](http://spark.apache.org/docs/1.5.0/api/java/org/apache/spark/streaming/) have a bunch of test classes in them. The only way I can reproduce this is to `test:compile` before running `unidoc`. To prevent this from happening again, I've added a clean before doc generation. Author: Michael Armbrust Closes #8787 from marmbrus/testsInDocs. --- docs/_plugins/copy_api_dirs.rb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 15ceda11a8a80..01718d98dffe0 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,12 +26,15 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `build/sbt -Pkinesis-asl compile unidoc` + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `build/sbt -Pkinesis-asl clean compile unidoc` puts "Moving back into docs dir." cd("docs") + puts "Removing old docs" + puts `rm -rf api` + # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. source = "../target/scala-2.10/unidoc" From aad644fbe29151aec9004817d42e4928bdb326f3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 17 Sep 2015 11:14:52 -0700 Subject: [PATCH 330/802] [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai Closes #8788 from yhuai/udafConversion. 
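For a concrete sense of what needs converting (a Scala sketch, since UDAFs are defined through the Scala/Java `UserDefinedAggregateFunction` API): an `evaluate` that returns a non-primitive Scala value, such as a `Seq` or a `Row`, hands back an object that cannot be stored in Catalyst's internal rows as-is, which is why this patch routes the result through `CatalystTypeConverters`. The class below is hypothetical and only illustrates the shape of such a UDAF:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Hypothetical UDAF whose result type is an array: evaluate() returns a plain Scala Seq,
// which must be converted to Catalyst's internal representation before it is emitted.
class CollectDoubles extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("value", DoubleType)
  def bufferSchema: StructType = new StructType().add("values", ArrayType(DoubleType))
  def dataType: DataType = ArrayType(DoubleType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Seq.empty[Double])

  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer.update(0, buffer.getSeq[Double](0) :+ input.getDouble(0))

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0))

  // Returns a Scala Seq, not a Catalyst value: without the conversion added in this patch,
  // the engine would receive an object type it does not expect.
  def evaluate(buffer: Row): Any = buffer.getSeq[Double](0)
}
```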
--- .../sql/catalyst/CatalystTypeConverters.scala | 7 +- .../spark/sql/RandomDataGenerator.scala | 16 ++- .../spark/sql/execution/aggregate/udaf.scala | 37 +++++- .../org/apache/spark/sql/QueryTest.scala | 21 ++-- .../spark/sql/UserDefinedTypeSuite.scala | 11 ++ .../execution/AggregationQuerySuite.scala | 108 +++++++++++++++++- 6 files changed, 188 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 966623ed017ba..f25591794abdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -138,8 +138,13 @@ object CatalystTypeConverters { private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + // toCatalyst (it calls toCatalystImpl) will do null check. override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) - override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + + override def toScala(catalystValue: Any): Any = { + if (catalystValue == null) null else udt.deserialize(catalystValue) + } + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column, udt.sqlType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 4025cbcec1019..e48395028e399 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -108,7 +108,21 @@ object RandomDataGenerator { arr }) case BooleanType => Some(() => rand.nextBoolean()) - case DateType => Some(() => new java.sql.Date(rand.nextInt())) + case DateType => + val generator = + () => { + var milliseconds = rand.nextLong() % 253402329599999L + // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT + // for "0001-01-01 00:00:00.000000". We need to find a + // number that is greater or equals to this number as a valid timestamp value. + while (milliseconds < -62135740800000L) { + // 253402329599999L is the the number of milliseconds since + // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999". 
+ milliseconds = rand.nextLong() % 253402329599999L + } + DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt) + } + Some(generator) case TimestampType => val generator = () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d43d3dd9ffaae..1114fe6552bdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < getters.length) { getters(i) = dataTypes(i) match { + case NullType => + (row: InternalRow, ordinal: Int) => null + case BooleanType => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal) @@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils { (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale) + case DateType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getInt(ordinal) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + if (row.isNullAt(ordinal)) null else row.getLong(ordinal) + case other => (row: InternalRow, ordinal: Int) => if (row.isNullAt(ordinal)) null else row.get(ordinal, other) @@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils { var i = 0 while (i < setters.length) { setters(i) = dataTypes(i) match { + case NullType => + (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal) + case b: BooleanType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { @@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils { case dt: DecimalType => val precision = dt.precision + (row: MutableRow, ordinal: Int, value: Any) => + // To make it work with UnsafeRow, we cannot use setNullAt. + // Please see the comment of UnsafeRow's setDecimal. + row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + + case DateType => (row: MutableRow, ordinal: Int, value: Any) => if (value != null) { - row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision) + row.setInt(ordinal, value.asInstanceOf[Int]) + } else { + row.setNullAt(ordinal) + } + + case TimestampType => + (row: MutableRow, ordinal: Int, value: Any) => + if (value != null) { + row.setLong(ordinal, value.asInstanceOf[Long]) } else { row.setNullAt(ordinal) } @@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i))) } @@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF( } } + private[this] lazy val outputToCatalystConverter: Any => Any = { + CatalystTypeConverters.createToCatalystConverter(dataType) + } + // This buffer is only used at executor side. 
private[this] var inputAggregateBuffer: InputAggregationBuffer = null @@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF( override def eval(buffer: InternalRow): Any = { evalAggregateBuffer.underlyingInputBuffer = buffer - udaf.evaluate(evalAggregateBuffer) + outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer)) } override def toString: String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index cada03e9ac6bb..e3c5a426671d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -115,19 +115,26 @@ object QueryTest { */ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case d: java.math.BigDecimal => BigDecimal(d) + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map { s => - Row.fromSeq(s.toSeq.map { - case d: java.math.BigDecimal => BigDecimal(d) - case b: Array[Byte] => b.toSeq - case o => o - }) - } + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } val sparkAnswer = try df.collect().toSeq catch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 46d87843dfa4d..7992fd59ff4ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext { assert(new MyDenseVectorUDT().typeName === "mydensevector") assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset") } + + test("Catalyst type converter null handling for UDTs") { + val udt = new MyDenseVectorUDT() + val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) + assert(toScalaConverter(null) === null) + + val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) + assert(toCatalystConverter(null) === null) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index a73b1bd52c09f..24b1846923c77 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,13 +17,55 @@ package org.apache.spark.sql.hive.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { + + def inputSchema: StructType = schema + + def bufferSchema: StructType = schema + + def dataType: DataType = schema + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + (0 until schema.length).foreach { i => + buffer.update(i, null) + } + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!input.isNullAt(0) && input.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer.update(i, input.get(i)) + } + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) { + (0 until schema.length).foreach { i => + buffer1.update(i, buffer2.get(i)) + } + } + } + + def evaluate(buffer: Row): Any = { + Row.fromSeq(buffer.toSeq) + } +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) } } + + test("udaf with all data types") { + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + // Right now, we will use SortBasedAggregate to handle UDAFs. + // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // UnsafeRow as the aggregation buffer. While, dataTypes will trigger + // SortBasedAggregate to use a safe row as the aggregation buffer. + Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + // The schema used for data generator. + val schemaForGenerator = StructType(fields) + // The schema used for the DataFrame df. + val schema = StructType(StructField("id", IntegerType) +: fields) + + logInfo(s"Testing schema: ${schema.treeString}") + + val udaf = new ScalaAggregateFunction(schema) + // Generate data at the driver side. We need to materialize the data first and then + // create RDD. 
+ val maybeDataGenerator = + RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + seed = Some(System.nanoTime())) + val dataGenerator = + maybeDataGenerator + .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator")) + val data = (1 to 50).map { i => + dataGenerator.apply() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => + fail(s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated.") + } + } + + // Create a DF for the schema with random data. + val rdd = sqlContext.sparkContext.parallelize(data, 1) + val df = sqlContext.createDataFrame(rdd, schema) + + val allColumns = df.schema.fields.map(f => col(f.name)) + val expectedAnaswer = + data + .find(r => r.getInt(0) == 50) + .getOrElse(fail("A row with id 50 should be the expected answer.")) + checkAnswer( + df.groupBy().agg(udaf(allColumns: _*)), + // udaf returns a Row as the output value. + Row(expectedAnaswer) + ) + } + } } class SortBasedAggregationQuerySuite extends AggregationQuerySuite { From 64743870f23bffb8d96dcc8a0181c1452782a151 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 11:24:38 -0700 Subject: [PATCH 331/802] [SPARK-10394] [ML] Make GBTParams use shared stepSize ```GBTParams``` has ```stepSize``` as learning rate currently. ML has shared param class ```HasStepSize```, ```GBTParams``` can extend from it rather than duplicated implementation. Author: Yanbo Liang Closes #8552 from yanboliang/spark-10394. --- .../org/apache/spark/ml/tree/treeParams.scala | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d29f5253c9c3f..42e74ce6d2c69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} @@ -365,17 +365,7 @@ private[ml] object RandomForestParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. 
SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) - /** @group setParam */ + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group setParam + */ def setStepSize(value: Double): this.type = set(stepSize, value) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) + override def validateParams(): Unit = { + require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( + getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + + s"but it given invalid value $getStepSize.") + } /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ private[ml] def getOldBoostingStrategy( From f1c911552cf5d0d60831c79c1881016293aec66c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 17 Sep 2015 11:40:24 -0700 Subject: [PATCH 332/802] [SPARK-10657] Remove SCP-based Jenkins log archiving As of https://issues.apache.org/jira/browse/SPARK-7561, we no longer need to use our custom SCP-based mechanism for archiving Jenkins logs on the master machine; this has been superseded by the use of a Jenkins plugin which archives the logs and provides public links to view them. Per shaneknapp, we should remove this log syncing mechanism if it is no longer necessary; removing the need to SCP from the Jenkins workers to the masters is a desired step as part of some larger Jenkins infra refactoring. Author: Josh Rosen Closes #8793 from JoshRosen/remove-jenkins-ssh-to-master. --- dev/run-tests-jenkins | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3be78575e70f1..d3b05fa6df0ce 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -116,39 +116,6 @@ function post_message () { fi } -function send_archived_logs () { - echo "Archiving unit tests logs..." - - local log_files=$( - find .\ - -name "unit-tests.log" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ - -path "./sql/hive/target/HiveCompatibilitySuite.wrong" - ) - - if [ -z "$log_files" ]; then - echo "> No log files found." >&2 - else - local log_archive="unit-tests-logs.tar.gz" - echo "$log_files" | xargs tar czf ${log_archive} - - local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} - local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) - local scp_status="$?" - - if [ "$scp_status" -ne 0 ]; then - echo "Failed to send archived unit tests logs to Jenkins master." >&2 - echo "> scp_status: ${scp_status}" >&2 - echo "> scp_output: ${scp_output}" >&2 - else - echo "> Send successful." - fi - - rm -f ${log_archive} - fi -} - # post start message { start_message="\ @@ -244,8 +211,6 @@ done test_result_note=" * This patch **fails $failing_test**." fi - - send_archived_logs } # post end message From 4fbf3328692e876f39ea78494510f9d9c5a53f15 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 14:09:06 -0700 Subject: [PATCH 333/802] [SPARK-9698] [ML] Add RInteraction transformer for supporting R-style feature interactions This is a pre-req for supporting the ":" operator in the RFormula feature transformer. 
Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7987 from ericl/interaction. --- .../apache/spark/ml/feature/Interaction.scala | 278 ++++++++++++++++++ .../spark/ml/feature/InteractionSuite.scala | 165 +++++++++++ 2 files changed, 443 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 0000000000000..9194763fb32f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.Transformer +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is + * produced. + * + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
+ */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } + + override def transform(dataset: DataFrame): DataFrame = { + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) + + def interactFunc = udf { row: Row => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var featureIndex = row.length - 1 + while (featureIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentEncoder = featureEncoders(featureIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + featureIndex -= 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) + } + } + dataset.select( + col("*"), + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) + } + + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) + case _ => + 1 // numeric feature + } + } + features.map { f => + val numFeatures = f.dataType match { + case _: NumericType | BooleanType => + Array(getNumFeatures(Attribute.fromStructField(f))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( + throw new SparkException("Vector attributes must be defined for interaction.")) + attrs.map(getNumFeatures).toArray + } + new FeatureEncoder(numFeatures) + }.toArray + } + + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`. 
+ * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) + } + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs + } else { + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new AttributeGroup($(outputCol), featureAttrs.toArray) + } + + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) + parts.flatten.mkString("_") + } + + inputAttrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + override def validateParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. + * + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. + */ +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + + /** The size of the output vector. */ + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + var i = 1 + while (i < arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 + } + arr + } + + /** + * Given an input row of features, invokes the specific function for every non-zero output. 
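+ * For example (an illustrative walk-through consistent with the tests below), with
+ * numFeatures = Array(3, 1) the input Vectors.dense(1.0, 2.2) triggers f(1, 1.0) for the
+ * one-hot position of the first (nominal) feature and f(3, 2.2) for the second (numeric)
+ * feature.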
+ * + * @param value The row value to encode, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. + */ + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { + case d: Double => + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { + assert( + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) + } else { + f(outputOffsets(i), v) + } + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 100644 index 0000000000000..2beb62ca08233 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("feature encoder") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed + } + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } + } + + test("numeric interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } + + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val 
res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) + } + + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } +} From 0f5ef6dfa67a068606aff8ea9d1addfce73446eb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 17 Sep 2015 19:16:34 -0700 Subject: [PATCH 334/802] [SPARK-10674] [TESTS] Increase timeouts in SaslIntegrationSuite. 1s seems to trigger too many times on the jenkins build boxes, so increase the timeout and cross fingers. Author: Marcelo Vanzin Closes #8802 from vanzin/SPARK-10674 and squashes the following commits: 3c93117 [Marcelo Vanzin] Use java 7 syntax. d667d1b [Marcelo Vanzin] [SPARK-10674] [tests] Increase timeouts in SaslIntegrationSuite. 
--- .../spark/network/sasl/SaslIntegrationSuite.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 5cb0e4d4a6458..c393a5e1e6810 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -56,6 +56,11 @@ import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { + + // Use a long timeout to account for slow / overloaded build machines. In the normal case, + // tests should finish way before the timeout expires. + private final static long TIMEOUT_MS = 10_000; + static TransportServer server; static TransportConf conf; static TransportContext context; @@ -102,7 +107,7 @@ public void testGoodClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; - byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS); assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg } @@ -131,7 +136,7 @@ public void testNoSaslClient() throws IOException { TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { - client.sendRpcSync(new byte[13], 1000); + client.sendRpcSync(new byte[13], TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); @@ -139,7 +144,7 @@ public void testNoSaslClient() throws IOException { try { // Guessing the right tag byte doesn't magically get you in... - client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS); fail("Should have failed"); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); @@ -217,12 +222,12 @@ public synchronized void onBlockFetchFailure(String blockId, Throwable t) { new String[] { System.getProperty("java.io.tmpdir") }, 1, "org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); - client1.sendRpcSync(regmsg.toByteArray(), 10000); + client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS); // Make a successful request to fetch blocks, which creates a new stream. But do not actually // fetch any blocks, to keep the stream open. OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds); - byte[] response = client1.sendRpcSync(openMessage.toByteArray(), 10000); + byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); long streamId = stream.streamId; From 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Sep 2015 21:37:10 -0700 Subject: [PATCH 335/802] [SPARK-8518] [ML] Log-linear models for survival analysis [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. 
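In sketch form, using the notation of the loss derivation added in this patch, the model assumes \log{t_{i}} = x^{'}\beta + \sigma\epsilon_{i}, where \epsilon_{i} follows the extreme value distribution so that the lifetime t_{i} is Weibull distributed.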
Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang Closes #8611 from yanboliang/spark-8518. --- .../ml/regression/AFTSurvivalRegression.scala | 449 ++++++++++++++++++ .../AFTSurvivalRegressionSuite.scala | 311 ++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala new file mode 100644 index 0000000000000..5b25db651f56c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.collection.mutable + +import breeze.linalg.{DenseVector => BDV} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.storage.StorageLevel + +/** + * Params for accelerated failure time (AFT) regression. + */ +private[regression] trait AFTSurvivalRegressionParams extends Params + with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter + with HasTol with HasFitIntercept { + + /** + * Param for censor column name. + * The value of this column could be 0 or 1. + * If the value is 1, it means the event has occurred i.e. uncensored; otherwise censored. 
+ * @group param + */ + @Since("1.6.0") + final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name") + + /** @group getParam */ + @Since("1.6.0") + def getCensorCol: String = $(censorCol) + setDefault(censorCol -> "censor") + + /** + * Param for quantile probabilities array. + * Values of the quantile probabilities array should be in the range [0, 1]. + * @group param + */ + @Since("1.6.0") + final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, + "quantileProbabilities", "quantile probabilities array", + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + + /** @group getParam */ + @Since("1.6.0") + def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + + /** Checks whether the input has quantile probabilities array. */ + protected[regression] def hasQuantileProbabilities: Boolean = { + isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + } + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting or prediction + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + if (fitting) { + SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: Experimental :: + * Fit a parametric survival regression model named accelerated failure time (AFT) model + * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]]) + * based on the Weibull distribution of the survival time. + */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String) + extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging { + + @Since("1.6.0") + def this() = this(Identifiable.randomUID("aftSurvReg")) + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setCensorCol(value: String): this.type = set(censorCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** + * Set if we should fit the intercept + * Default is true. + * @group setParam + */ + @Since("1.6.0") + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + setDefault(fitIntercept -> true) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + @Since("1.6.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + @Since("1.6.0") + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) + + /** + * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, + * and put it in an RDD with strong types. 
+ */ + protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + dataset.select($(featuresCol), $(labelCol), $(censorCol)).map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } + } + + @Since("1.6.0") + override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + validateAndTransformSchema(dataset.schema, fitting = true) + val instances = extractAFTPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val costFun = new AFTCostFun(instances, $(fitIntercept)) + val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + + val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + /* + The weights vector has three parts: + the first element: Double, log(sigma), the log of scale parameter + the second element: Double, intercept of the beta parameter + the third to the end elements: Doubles, regression coefficients vector of the beta parameter + */ + val initialWeights = Vectors.zeros(numFeatures + 2) + + val states = optimizer.iterations(new CachedDiffFunction(costFun), + initialWeights.toBreeze.toDenseVector) + + val weights = { + val arrayBuilder = mutable.ArrayBuilder.make[Double] + var state: optimizer.State = null + while (states.hasNext) { + state = states.next() + arrayBuilder += state.adjustedValue + } + if (state == null) { + val msg = s"${optimizer.getClass.getName} failed." + throw new SparkException(msg) + } + + state.x.toArray.clone() + } + + if (handlePersistence) instances.unpersist() + + val coefficients = Vectors.dense(weights.slice(2, weights.length)) + val intercept = weights(1) + val scale = math.exp(weights(0)) + val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) + copyValues(model.setParent(this)) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model produced by [[AFTSurvivalRegression]]. 
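+ *
+ * A minimal usage sketch (`training` is an assumed DataFrame with "features", "label" and
+ * "censor" columns; `features` is an assumed feature [[Vector]]):
+ * {{{
+ *   val model = new AFTSurvivalRegression().setCensorCol("censor").fit(training)
+ *   val predictions = model.transform(training)
+ *   model.setQuantileProbabilities(Array(0.1, 0.5, 0.9))
+ *   val quantiles = model.predictQuantiles(features)
+ * }}}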
+ */ +@Experimental +@Since("1.6.0") +class AFTSurvivalRegressionModel private[ml] ( + @Since("1.6.0") override val uid: String, + @Since("1.6.0") val coefficients: Vector, + @Since("1.6.0") val intercept: Double, + @Since("1.6.0") val scale: Double) + extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams { + + /** @group setParam */ + @Since("1.6.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + @Since("1.6.0") + def predictQuantiles(features: Vector): Vector = { + require(hasQuantileProbabilities, + "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") + // scale parameter for the Weibull distribution of lifetime + val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) + // shape parameter for the Weibull distribution of lifetime + val k = 1 / scale + val quantiles = $(quantileProbabilities).map { + q => lambda * math.exp(math.log(-math.log(1 - q)) / k) + } + Vectors.dense(quantiles) + } + + @Since("1.6.0") + def predict(features: Vector): Double = { + math.exp(BLAS.dot(coefficients, features) + intercept) + } + + @Since("1.6.0") + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema) + val predictUDF = udf { features: Vector => predict(features) } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + @Since("1.6.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false) + } + + @Since("1.6.0") + override def copy(extra: ParamMap): AFTSurvivalRegressionModel = { + copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra) + .setParent(parent) + } +} + +/** + * AFTAggregator computes the gradient and loss for a AFT loss function, + * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * + * The loss function and likelihood function under the AFT model based on: + * Lawless, J. F., Statistical Models and Methods for Lifetime Data, + * New York: John Wiley & Sons, Inc. 2003. + * + * Two AFTAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + * Given the values of the covariates x^{'}, for random lifetime t_{i} of subjects i = 1, ..., n, + * with possible right-censoring, the likelihood function under the AFT model is given as + * {{{ + * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0} + * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} + * }}} + * Where \delta_{i} is the indicator of the event has occurred i.e. uncensored or not. + * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function + * assumes the form + * {{{ + * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+ + * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] + * }}} + * Where S_{0}(\epsilon_{i}) is the baseline survivor function, + * and f_{0}(\epsilon_{i}) is corresponding density function. + * + * The most commonly used log-linear survival regression method is based on the Weibull + * distribution of the survival time. 
The Weibull distribution for lifetime corresponding + * to extreme value distribution for log of the lifetime, + * and the S_{0}(\epsilon) function is + * {{{ + * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) + * }}} + * the f_{0}(\epsilon_{i}) function is + * {{{ + * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) + * }}} + * The log-likelihood function for Weibull distribution of lifetime is + * {{{ + * \iota(\beta,\sigma)= + * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] + * }}} + * Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, + * the loss function we use to optimize is -\iota(\beta,\sigma). + * The gradient functions for \beta and \log\sigma respectively are + * {{{ + * \frac{\partial (-\iota)}{\partial \beta}= + * \sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} + * }}} + * {{{ + * \frac{\partial (-\iota)}{\partial (\log\sigma)}= + * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] + * }}} + * @param weights The log of scale parameter, the intercept and + * regression coefficients corresponding to the features. + * @param fitIntercept Whether to fit an intercept term. + */ +private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean) + extends Serializable { + + // beta is the intercept and regression coefficients to the covariates + private val beta = weights.slice(1, weights.length) + // sigma is the scale parameter of the AFT model + private val sigma = math.exp(weights(0)) + + private var totalCnt: Long = 0L + private var lossSum = 0.0 + private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientLogSigmaSum = 0.0 + + def count: Long = totalCnt + + def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt + + // Here we optimize loss function over beta and log(sigma) + def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), + gradientBetaSum/totalCnt.toDouble) + + /** + * Add a new training data to this AFTAggregator, and update the loss and gradient + * of the objective function. + * + * @param data The AFTPoint representation for one data point to be added into this aggregator. + * @return This AFTAggregator object. + */ + def add(data: AFTPoint): this.type = { + + // TODO: Don't create a new xi vector each time. + val xi = if (fitIntercept) { + Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze + } else { + Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze + } + val ti = data.label + val delta = data.censor + val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + + lossSum += math.log(sigma) * delta + lossSum += (math.exp(epsilon) - delta * epsilon) + + // Sanity check (should never occur): + assert(!lossSum.isInfinity, + s"AFTAggregator loss sum is infinity. Error for unknown reason.") + + gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma + gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + + totalCnt += 1 + this + } + + /** + * Merge another AFTAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other AFTAggregator to be merged. + * @return This AFTAggregator object. 
+ */ + def merge(other: AFTAggregator): this.type = { + if (totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + + gradientBetaSum += other.gradientBetaSum + gradientLogSigmaSum += other.gradientLogSigmaSum + } + this + } +} + +/** + * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost. + * It returns the loss and gradient at a particular point (coefficients). + * It's used in Breeze's convex optimization routines. + */ +private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) + extends DiffFunction[BDV[Double]] { + + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + + val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))( + seqOp = (c, v) => (c, v) match { + case (aggregator, instance) => aggregator.add(instance) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + (aftAggregator.loss, aftAggregator.gradient) + } +} + +/** + * Class that represents the (features, label, censor) of a data point. + * + * @param features List of features for this data point. + * @param label Label for this data point. + * @param censor Indicator of the event has occurred or not. If the value is 1, it means + * the event has occurred i.e. uncensored; otherwise censored. + */ +private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { + require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 0000000000000..ca7140a45ea65 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. 
Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. 
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} From d009da2f5c803f3b7344c96abbfcf3ecef2f5ad2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:05:20 -0700 Subject: [PATCH 336/802] [SPARK-10682] [GRAPHX] Remove Bagel test suites. Bagel has been deprecated and we haven't done any changes to it. 
There is no need to run those tests. This should speed up tests by 1 min. Author: Reynold Xin Closes #8807 from rxin/SPARK-10682. --- bagel/src/test/resources/log4j.properties | 27 ----- .../org/apache/spark/bagel/BagelSuite.scala | 113 ------------------ 2 files changed, 140 deletions(-) delete mode 100644 bagel/src/test/resources/log4j.properties delete mode 100644 bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties deleted file mode 100644 index edbecdae92096..0000000000000 --- a/bagel/src/test/resources/log4j.properties +++ /dev/null @@ -1,27 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala deleted file mode 100644 index fb10d734ac74b..0000000000000 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.bagel - -import org.scalatest.{BeforeAndAfter, Assertions} -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel - -class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable -class TestMessage(val targetId: String) extends Message[String] with Serializable - -class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { - - var sc: SparkContext = _ - - after { - if (sc != null) { - sc.stop() - sc = null - } - } - - test("halting by voting") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("halting by message silence") { - sc = new SparkContext("local", "test") - val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0)))) - val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) - val numSupersteps = 5 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - val msgsOut = - msgs match { - case Some(ms) if (superstep < numSupersteps - 1) => - ms - case _ => - Array[TestMessage]() - } - (new TestVertex(self.active, self.age + 1), msgsOut) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - - test("large number of iterations") { - // This tests whether jobs with a large number of iterations finish in a reasonable time, - // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(30 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } - - test("using non-default persistence level") { - failAfter(10 seconds) { - sc = new SparkContext("local", "test") - val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) - val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 20 - val result = - Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { - (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => - (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) - } - for ((id, vert) <- result.collect) { - assert(vert.age === numSupersteps) - } - } - } -} From 93c7650ab60a839a9cbe8b4ea1d5eda93e53ebe0 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Thu, 17 Sep 2015 22:25:24 -0700 Subject: [PATCH 337/802] [SPARK-9522] [SQL] SparkSubmit process can not exit if kill application when HiveThriftServer was starting When we start 
HiveThriftServer, we will start SparkContext first, then start HiveServer2, if we kill application while HiveServer2 is starting then SparkContext will stop successfully, but SparkSubmit process can not exit. Author: linweizhong Closes #7853 from Sephiroth-Lin/SPARK-9522. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/sql/hive/thriftserver/HiveThriftServer2.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9c3218719f7fc..ebd8e946ee7a2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -97,7 +97,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val startTime = System.currentTimeMillis() - private val stopped: AtomicBoolean = new AtomicBoolean(false) + private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) private def assertNotStopped(): Unit = { if (stopped.get()) { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index dd9fef9206d0b..a0643cec0fb7c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -93,6 +93,12 @@ object HiveThriftServer2 extends Logging { } else { None } + // If application was killed before HiveThriftServer2 start successfully then SparkSubmit + // process can not exit, so check whether if SparkContext was stopped. + if (SparkSQLEnv.sparkContext.stopped.get()) { + logError("SparkContext has stopped even if HiveServer2 has started, so exit") + System.exit(-1) + } } catch { case e: Exception => logError("Error starting HiveThriftServer2", e) From 9a56dcdf7f19c9f7f913a2ce9bc981cb43a113c5 Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Thu, 17 Sep 2015 22:42:46 -0700 Subject: [PATCH 338/802] docs/running-on-mesos.md: state default values in default column This PR simply uses the default value column for defaults. Author: Felix Bechstein Closes #8810 from felixb/fix_mesos_doc. --- docs/running-on-mesos.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 247e6ecfbdb86..1814fb32ed8a5 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -332,21 +332,21 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.principal - Framework principal to authenticate to Mesos + (none) Set the principal with which Spark framework will use to authenticate with Mesos. spark.mesos.secret - Framework secret to authenticate to Mesos + (none)/td> Set the secret with which Spark framework will use to authenticate with Mesos. spark.mesos.role - Role for the Spark framework + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. @@ -354,7 +354,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.constraints - Attribute based constraints to be matched against when accepting resource offers. + (none) Attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. 
Refer to Mesos Attributes & Resources for more information on attributes.
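For illustration only (the attribute names and values below are assumptions, not shipped defaults), such a constraint could be supplied at submit time through the generic configuration mechanism:

    ./bin/spark-submit \
      --master mesos://<mesos-master-host>:5050 \
      --conf spark.mesos.constraints="os:centos7;us-east-1:false" \
      <application-jar>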
    From 74d8f7dda82c3a16348f3ff22da83203e0b7f708 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 17 Sep 2015 22:46:13 -0700 Subject: [PATCH 339/802] Added tag to documentation. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 1814fb32ed8a5..330c159c67bca 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -346,7 +346,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.role - * + * Set the role of this Spark framework for Mesos. Roles are used in Mesos for reservations and resource weight sharing. From e3b5d6cb29e0f983fcc55920619e6433298955f5 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 18 Sep 2015 00:43:02 -0700 Subject: [PATCH 340/802] [SPARK-10684] [SQL] StructType.interpretedOrdering need not to be serialized Kryo fails with buffer overflow even with max value (2G). {noformat} org.apache.spark.SparkException: Kryo serialization failed: Buffer overflow. Available: 0, required: 1 Serialization trace: containsChild (org.apache.spark.sql.catalyst.expressions.BoundReference) child (org.apache.spark.sql.catalyst.expressions.SortOrder) array (scala.collection.mutable.ArraySeq) ordering (org.apache.spark.sql.catalyst.expressions.InterpretedOrdering) interpretedOrdering (org.apache.spark.sql.types.StructType) schema (org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema). To avoid this, increase spark.kryoserializer.buffer.max value. at org.apache.spark.serializer.KryoSerializerInstance.serialize(KryoSerializer.scala:263) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:240) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) {noformat} Author: navis.ryu Closes #8808 from navis/SPARK-10684. --- .../main/scala/org/apache/spark/sql/types/StructType.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d8968ef806390..b29cf22dcb582 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -305,7 +305,9 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru f(this) || fields.exists(field => field.dataType.existsRecursively(f)) } - private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) + @transient + private[sql] lazy val interpretedOrdering = + InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } object StructType extends AbstractDataType { From 20fd35dfd1ac402b622604e7bbedcc53a580b0a2 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 18 Sep 2015 08:22:38 -0700 Subject: [PATCH 341/802] [SPARK-10451] [SQL] Prevent unnecessary serializations in InMemoryColumnarTableScan Many of the fields in InMemoryColumnar scan and InMemoryRelation can be made transient. This reduces my 1000ms job to abt 700 ms . The task size reduces from 2.8 mb to ~1300kb Author: Yash Datta Closes #8604 from saucam/serde. 
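Both techniques used here (marking driver-only fields @transient, and copying the few values a closure actually needs into local vals) can be sketched in isolation as follows; the class and field names are illustrative stand-ins, not the actual Spark types:

    import org.apache.spark.rdd.RDD

    // Illustrative stand-in for a plan node. Heavy, driver-only members are @transient so
    // serialization skips them if the object is ever pulled into a task closure.
    class CachedScan(
        @transient val childPlan: AnyRef,            // driver-only
        @transient val buffers: RDD[Array[Byte]],    // driver-side handle to cached batches
        val columnNames: Seq[String]) extends Serializable {

      def scan(): RDD[Int] = {
        // Copy just the small value the closure needs into a local val so the closure
        // captures that value instead of `this` (which would drag in childPlan, etc.).
        val names = columnNames
        buffers.mapPartitions { iter =>
          iter.map(bytes => bytes.length + names.length)
        }
      }
    }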
--- .../columnar/InMemoryColumnarTableScan.scala | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 66d429bc06198..d7e145f9c2bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -48,10 +48,10 @@ private[sql] case class InMemoryRelation( useCompression: Boolean, batchSize: Int, storageLevel: StorageLevel, - child: SparkPlan, + @transient child: SparkPlan, tableName: Option[String])( - private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null, + @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private var _statistics: Statistics = null, private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { @@ -62,7 +62,7 @@ private[sql] case class InMemoryRelation( _batchStats } - val partitionStatistics = new PartitionStatistics(output) + @transient val partitionStatistics = new PartitionStatistics(output) private def computeSizeInBytes = { val sizeOfRow: Expression = @@ -196,7 +196,7 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], predicates: Seq[Expression], - relation: InMemoryRelation) + @transient relation: InMemoryRelation) extends LeafNode { override def output: Seq[Attribute] = attributes @@ -205,7 +205,7 @@ private[sql] case class InMemoryColumnarTableScan( // Returned filter predicate should return false iff it is impossible for the input expression // to evaluate to `true' based on statistics collected about this partition batch. - val buildFilter: PartialFunction[Expression, Expression] = { + @transient val buildFilter: PartialFunction[Expression, Expression] = { case And(lhs: Expression, rhs: Expression) if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) => (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _) @@ -268,16 +268,23 @@ private[sql] case class InMemoryColumnarTableScan( readBatches.setValue(0) } - relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => - val partitionFilter = newPredicate( - partitionFilters.reduceOption(And).getOrElse(Literal(true)), - relation.partitionStatistics.schema) + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val schema = relation.partitionStatistics.schema + val schemaIndex = schema.zipWithIndex + val relOutput = relation.output + val buffers = relation.cachedColumnBuffers + + buffers.mapPartitions { cachedBatchIterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + schema) // Find the ordinals and data types of the requested columns. If none are requested, use the // narrowest (the field with minimum default element size). 
val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { val (narrowestOrdinal, narrowestDataType) = - relation.output.zipWithIndex.map { case (a, ordinal) => + relOutput.zipWithIndex.map { case (a, ordinal) => ordinal -> a.dataType } minBy { case (_, dataType) => ColumnType(dataType).defaultSize @@ -285,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan( Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { attributes.map { a => - relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType }.unzip } @@ -296,7 +303,7 @@ private[sql] case class InMemoryColumnarTableScan( // Build column accessors val columnAccessors = requestedColumnIndices.map { batchColumnIndex => ColumnAccessor( - relation.output(batchColumnIndex).dataType, + relOutput(batchColumnIndex).dataType, ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex))) } @@ -328,7 +335,7 @@ private[sql] case class InMemoryColumnarTableScan( if (inMemoryPartitionPruningEnabled) { cachedBatchIterator.filter { cachedBatch => if (!partitionFilter(cachedBatch.stats)) { - def statsString: String = relation.partitionStatistics.schema.zipWithIndex.map { + def statsString: String = schemaIndex.map { case (a, i) => val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" From 35e8ab939000d4a1a01c1af4015c25ff6f4013a3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Sep 2015 09:53:52 -0700 Subject: [PATCH 342/802] [SPARK-10615] [PYSPARK] change assertEquals to assertEqual As ```assertEquals``` is deprecated, so we need to change ```assertEquals``` to ```assertEqual``` for existing python unit tests. Author: Yanbo Liang Closes #8814 from yanboliang/spark-10615. --- python/pyspark/ml/tests.py | 16 +-- python/pyspark/mllib/tests.py | 162 +++++++++++++++--------------- python/pyspark/sql/tests.py | 18 ++-- python/pyspark/streaming/tests.py | 2 +- 4 files changed, 99 insertions(+), 99 deletions(-) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b892318f50bd9..648fa8858fba3 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -182,7 +182,7 @@ def test_params(self): self.assertEqual(testParams.getMaxIter(), 10) testParams.setMaxIter(100) self.assertTrue(testParams.isSet(maxIter)) - self.assertEquals(testParams.getMaxIter(), 100) + self.assertEqual(testParams.getMaxIter(), 100) self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) @@ -195,7 +195,7 @@ def test_params(self): testParams._setDefault(seed=41) testParams.setSeed(43) - self.assertEquals( + self.assertEqual( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", @@ -264,23 +264,23 @@ def test_ngram(self): self.assertEqual(ngram0.getInputCol(), "input") self.assertEqual(ngram0.getOutputCol(), "output") transformedDF = ngram0.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): sqlContext = SQLContext(self.sc) dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default - self.assertEquals(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getInputCol(), "input") transformedDF = stopWordRemover.transform(dataset) - 
self.assertEquals(transformedDF.head().output, ["panda"]) + self.assertEqual(transformedDF.head().output, ["panda"]) # Custom stopwords = ["panda"] stopWordRemover.setStopWords(stopwords) - self.assertEquals(stopWordRemover.getInputCol(), "input") - self.assertEquals(stopWordRemover.getStopWords(), stopwords) + self.assertEqual(stopWordRemover.getInputCol(), "input") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) - self.assertEquals(transformedDF.head().output, ["a"]) + self.assertEqual(transformedDF.head().output, ["a"]) class HasInducedError(Params): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 636f9a06cab7b..96cf13495aa95 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -166,13 +166,13 @@ def test_dot(self): [1., 2., 3., 4.], [1., 2., 3., 4.]]) arr = pyarray.array('d', [0, 1, 2, 3]) - self.assertEquals(10.0, sv.dot(dv)) + self.assertEqual(10.0, sv.dot(dv)) self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) - self.assertEquals(30.0, dv.dot(dv)) + self.assertEqual(30.0, dv.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) - self.assertEquals(30.0, lst.dot(dv)) + self.assertEqual(30.0, lst.dot(dv)) self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) - self.assertEquals(7.0, sv.dot(arr)) + self.assertEqual(7.0, sv.dot(arr)) def test_squared_distance(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -181,27 +181,27 @@ def test_squared_distance(self): lst1 = [4, 3, 2, 1] arr = pyarray.array('d', [0, 2, 1, 3]) narr = array([0, 2, 1, 3]) - self.assertEquals(15.0, _squared_distance(sv, dv)) - self.assertEquals(25.0, _squared_distance(sv, lst)) - self.assertEquals(20.0, _squared_distance(dv, lst)) - self.assertEquals(15.0, _squared_distance(dv, sv)) - self.assertEquals(25.0, _squared_distance(lst, sv)) - self.assertEquals(20.0, _squared_distance(lst, dv)) - self.assertEquals(0.0, _squared_distance(sv, sv)) - self.assertEquals(0.0, _squared_distance(dv, dv)) - self.assertEquals(0.0, _squared_distance(lst, lst)) - self.assertEquals(25.0, _squared_distance(sv, lst1)) - self.assertEquals(3.0, _squared_distance(sv, arr)) - self.assertEquals(3.0, _squared_distance(sv, narr)) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) def test_hash(self): v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(hash(v1), hash(v2)) - self.assertEquals(hash(v1), hash(v3)) - self.assertEquals(hash(v2), hash(v3)) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) self.assertFalse(hash(v1) == hash(v4)) self.assertFalse(hash(v2) == hash(v4)) @@ -212,8 +212,8 @@ def test_eq(self): v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) v6 = 
SparseVector(4, [(1, 1.0), (3, 2.5)]) - self.assertEquals(v1, v2) - self.assertEquals(v1, v3) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) self.assertFalse(v2 == v4) self.assertFalse(v1 == v5) self.assertFalse(v1 == v6) @@ -238,13 +238,13 @@ def test_conversion(self): def test_sparse_vector_indexing(self): sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv[0], 0.) - self.assertEquals(sv[3], 2.) - self.assertEquals(sv[1], 1.) - self.assertEquals(sv[2], 0.) - self.assertEquals(sv[-1], 2) - self.assertEquals(sv[-2], 0) - self.assertEquals(sv[-4], 0) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[-1], 2) + self.assertEqual(sv[-2], 0) + self.assertEqual(sv[-4], 0) for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) for ind in [7.8, '1']: @@ -255,7 +255,7 @@ def test_matrix_indexing(self): expected = [[0, 6], [1, 8], [4, 10]] for i in range(3): for j in range(2): - self.assertEquals(mat[i, j], expected[i][j]) + self.assertEqual(mat[i, j], expected[i][j]) def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -308,11 +308,11 @@ def test_sparse_matrix(self): # Test sparse matrix creation. sm1 = SparseMatrix( 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) - self.assertEquals(sm1.numRows, 3) - self.assertEquals(sm1.numCols, 4) - self.assertEquals(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) - self.assertEquals(sm1.rowIndices.tolist(), [1, 2, 1, 2]) - self.assertEquals(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) self.assertTrue( repr(sm1), 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') @@ -325,13 +325,13 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1[i, j]) + self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) # Test conversion to dense and sparse. 
smnew = sm1.toDense().toSparse() - self.assertEquals(sm1.numRows, smnew.numRows) - self.assertEquals(sm1.numCols, smnew.numCols) + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) self.assertTrue(array_equal(sm1.values, smnew.values)) @@ -339,11 +339,11 @@ def test_sparse_matrix(self): sm1t = SparseMatrix( 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], isTransposed=True) - self.assertEquals(sm1t.numRows, 3) - self.assertEquals(sm1t.numCols, 4) - self.assertEquals(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) - self.assertEquals(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) - self.assertEquals(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) expected = [ [3, 2, 0, 0], @@ -352,18 +352,18 @@ def test_sparse_matrix(self): for i in range(3): for j in range(4): - self.assertEquals(expected[i][j], sm1t[i, j]) + self.assertEqual(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) def test_dense_matrix_is_transposed(self): mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) - self.assertEquals(mat1, mat) + self.assertEqual(mat1, mat) expected = [[0, 4], [1, 6], [3, 9]] for i in range(3): for j in range(2): - self.assertEquals(mat1[i, j], expected[i][j]) + self.assertEqual(mat1[i, j], expected[i][j]) self.assertTrue(array_equal(mat1.toArray(), expected)) sm = mat1.toSparse() @@ -412,8 +412,8 @@ def test_kmeans(self): ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||", initializationSteps=7, epsilon=1e-4) - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_kmeans_deterministic(self): from pyspark.mllib.clustering import KMeans @@ -443,8 +443,8 @@ def test_gmm(self): clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, maxIterations=10, seed=56) labels = clusters.predict(data).collect() - self.assertEquals(labels[0], labels[1]) - self.assertEquals(labels[2], labels[3]) + self.assertEqual(labels[0], labels[1]) + self.assertEqual(labels[2], labels[3]) def test_gmm_deterministic(self): from pyspark.mllib.clustering import GaussianMixture @@ -456,7 +456,7 @@ def test_gmm_deterministic(self): clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): - self.assertEquals(round(c1, 7), round(c2, 7)) + self.assertEqual(round(c1, 7), round(c2, 7)) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -711,18 +711,18 @@ def test_serialize(self): lil[1, 0] = 1 lil[3, 0] = 2 sv = SparseVector(4, {1: 1, 3: 2}) - self.assertEquals(sv, _convert_to_vector(lil)) - self.assertEquals(sv, _convert_to_vector(lil.tocsc())) - self.assertEquals(sv, _convert_to_vector(lil.tocoo())) - self.assertEquals(sv, _convert_to_vector(lil.tocsr())) - self.assertEquals(sv, 
_convert_to_vector(lil.todok())) + self.assertEqual(sv, _convert_to_vector(lil)) + self.assertEqual(sv, _convert_to_vector(lil.tocsc())) + self.assertEqual(sv, _convert_to_vector(lil.tocoo())) + self.assertEqual(sv, _convert_to_vector(lil.tocsr())) + self.assertEqual(sv, _convert_to_vector(lil.todok())) def serialize(l): return ser.loads(ser.dumps(_convert_to_vector(l))) - self.assertEquals(sv, serialize(lil)) - self.assertEquals(sv, serialize(lil.tocsc())) - self.assertEquals(sv, serialize(lil.tocsr())) - self.assertEquals(sv, serialize(lil.todok())) + self.assertEqual(sv, serialize(lil)) + self.assertEqual(sv, serialize(lil.tocsc())) + self.assertEqual(sv, serialize(lil.tocsr())) + self.assertEqual(sv, serialize(lil.todok())) def test_dot(self): from scipy.sparse import lil_matrix @@ -730,7 +730,7 @@ def test_dot(self): lil[1, 0] = 1 lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) - self.assertEquals(10.0, dv.dot(lil)) + self.assertEqual(10.0, dv.dot(lil)) def test_squared_distance(self): from scipy.sparse import lil_matrix @@ -739,8 +739,8 @@ def test_squared_distance(self): lil[3, 0] = 2 dv = DenseVector(array([1., 2., 3., 4.])) sv = SparseVector(4, {0: 1, 1: 2, 2: 3, 3: 4}) - self.assertEquals(15.0, dv.squared_distance(lil)) - self.assertEquals(15.0, sv.squared_distance(lil)) + self.assertEqual(15.0, dv.squared_distance(lil)) + self.assertEqual(15.0, sv.squared_distance(lil)) def scipy_matrix(self, size, values): """Create a column SciPy matrix from a dictionary of values""" @@ -759,8 +759,8 @@ def test_clustering(self): self.scipy_matrix(3, {2: 1.1}) ] clusters = KMeans.train(self.sc.parallelize(data), 2, initializationMode="k-means||") - self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) - self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + self.assertEqual(clusters.predict(data[0]), clusters.predict(data[1])) + self.assertEqual(clusters.predict(data[2]), clusters.predict(data[3])) def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes @@ -984,12 +984,12 @@ def test_word2vec_setters(self): .setNumIterations(10) \ .setSeed(1024) \ .setMinCount(3) - self.assertEquals(model.vectorSize, 2) + self.assertEqual(model.vectorSize, 2) self.assertTrue(model.learningRate < 0.02) - self.assertEquals(model.numPartitions, 2) - self.assertEquals(model.numIterations, 10) - self.assertEquals(model.seed, 1024) - self.assertEquals(model.minCount, 3) + self.assertEqual(model.numPartitions, 2) + self.assertEqual(model.numIterations, 10) + self.assertEqual(model.seed, 1024) + self.assertEqual(model.minCount, 3) def test_word2vec_get_vectors(self): data = [ @@ -1002,7 +1002,7 @@ def test_word2vec_get_vectors(self): ["a"] ] model = Word2Vec().fit(self.sc.parallelize(data)) - self.assertEquals(len(model.getVectors()), 3) + self.assertEqual(len(model.getVectors()), 3) class StandardScalerTests(MLlibTestCase): @@ -1044,8 +1044,8 @@ def test_model_params(self): """Test that the model params are set correctly""" stkm = StreamingKMeans() stkm.setK(5).setDecayFactor(0.0) - self.assertEquals(stkm._k, 5) - self.assertEquals(stkm._decayFactor, 0.0) + self.assertEqual(stkm._k, 5) + self.assertEqual(stkm._decayFactor, 0.0) # Model not set yet. 
self.assertIsNone(stkm.latestModel()) @@ -1053,9 +1053,9 @@ def test_model_params(self): stkm.setInitialCenters( centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) - self.assertEquals( + self.assertEqual( stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) - self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [1.0, 1.0]) def test_accuracy_for_single_center(self): """Test that parameters obtained are correct for a single center.""" @@ -1070,7 +1070,7 @@ def test_accuracy_for_single_center(self): self.ssc.start() def condition(): - self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + self.assertEqual(stkm.latestModel().clusterWeights, [25.0]) return True self._eventually(condition, catch_assertions=True) @@ -1114,7 +1114,7 @@ def test_trainOn_model(self): def condition(): finalModel = stkm.latestModel() self.assertTrue(all(finalModel.centers == array(initCenters))) - self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) return True self._eventually(condition, catch_assertions=True) @@ -1141,7 +1141,7 @@ def update(rdd): self.ssc.start() def condition(): - self.assertEquals(result, [[0], [1], [2], [3]]) + self.assertEqual(result, [[0], [1], [2], [3]]) return True self._eventually(condition, catch_assertions=True) @@ -1263,7 +1263,7 @@ def test_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(models), len(input_batches)) + self.assertEqual(len(models), len(input_batches)) return True # We want all batches to finish for this test. @@ -1297,7 +1297,7 @@ def test_predictions(self): self.ssc.start() def condition(): - self.assertEquals(len(true_predicted), len(input_batches)) + self.assertEqual(len(true_predicted), len(input_batches)) return True self._eventually(condition, catch_assertions=True) @@ -1400,7 +1400,7 @@ def test_parameter_convergence(self): self.ssc.start() def condition(): - self.assertEquals(len(model_weights), len(batches)) + self.assertEqual(len(model_weights), len(batches)) return True # We want all batches to finish for this test. @@ -1433,7 +1433,7 @@ def test_prediction(self): self.ssc.start() def condition(): - self.assertEquals(len(samples), len(batches)) + self.assertEqual(len(samples), len(batches)) return True # We want all batches to finish for this test. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f2172b7a27d88..3e680f1030a71 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,7 +157,7 @@ class DataTypeTests(unittest.TestCase): def test_data_type_eq(self): lt = LongType() lt2 = pickle.loads(pickle.dumps(LongType())) - self.assertEquals(lt, lt2) + self.assertEqual(lt, lt2) # regression test for SPARK-7978 def test_decimal_type(self): @@ -393,7 +393,7 @@ def test_infer_nested_schema(self): CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) df = self.sqlCtx.inferSchema(rdd) - self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] @@ -403,7 +403,7 @@ def test_create_dataframe_from_objects(self): def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") - self.assertEquals(Row(col=None), df.first()) + self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): from datetime import date, datetime @@ -519,14 +519,14 @@ def test_apply_schema_with_udt(self): StructField("point", ExamplePointUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT @@ -554,14 +554,14 @@ def test_parquet_with_udt(self): df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, ExamplePoint(1.0, 2.0)) + self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point - self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_column_operators(self): ci = self.df.key @@ -826,8 +826,8 @@ def test_infer_long_type(self): output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.saveAsParquetFile(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) - self.assertEquals('a', df1.first().f1) - self.assertEquals(100000000000000, df1.first().f2) + self.assertEqual('a', df1.first().f1) + self.assertEqual(100000000000000, df1.first().f2) self.assertEqual(_infer_type(1), LongType()) self.assertEqual(_infer_type(2**10), LongType()) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index cfea95b0dec71..e4e56fff3b3fc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -693,7 +693,7 @@ def check_output(n): # Verify that getActiveOrCreate() returns active context self.setupCalled = False - self.assertEquals(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) + self.assertEqual(StreamingContext.getActiveOrCreate(self.cpd, setup), self.ssc) self.assertFalse(self.setupCalled) # Verify that getActiveOrCreate() uses existing SparkContext From 
00a2911c5bea67a1a4796fb1d6fd5d0a8ee79001 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 12:19:08 -0700 Subject: [PATCH 343/802] [SPARK-10540] Fixes flaky all-data-type test This PR breaks the original test case into multiple ones (one test case for each data type). In this way, test failure output can be much more readable. Within each test case, we build a table with two columns, one of them is for the data type to test, the other is an "index" column, which is used to sort the DataFrame and workaround [SPARK-10591] [1] [1]: https://issues.apache.org/jira/browse/SPARK-10591 Author: Cheng Lian Closes #8768 from liancheng/spark-10540/test-all-data-types. --- .../sql/sources/hadoopFsRelationSuites.scala | 109 +++++++----------- 1 file changed, 43 insertions(+), 66 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8ffcef85668d6..d7504936d90e5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -100,80 +100,57 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } } - ignore("test all data types") { - withTempPath { file => - // Create the schema. - val struct = - StructType( - StructField("f1", FloatType, true) :: - StructField("f2", ArrayType(BooleanType), true) :: Nil) - // TODO: add CalendarIntervalType to here once we can save it out. - val dataTypes = - Seq( - StringType, BinaryType, NullType, BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), - DateType, TimestampType, - ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) - val fields = dataTypes.zipWithIndex.map { case (dataType, index) => - StructField(s"col$index", dataType, nullable = true) - } - val schema = StructType(fields) - - // Generate data at the driver side. We need to materialize the data first and then - // create RDD. - val maybeDataGenerator = - RandomDataGenerator.forType( - dataType = schema, + private val supportedDataTypes = Seq( + StringType, BinaryType, + NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), + new MyDenseVectorUDT() + ).filter(supportsDataType) + + for (dataType <- supportedDataTypes) { + test(s"test all data types - $dataType") { + withTempPath { file => + val path = file.getCanonicalPath + + val dataGenerator = RandomDataGenerator.forType( + dataType = dataType, nullable = true, - seed = Some(System.nanoTime())) - val dataGenerator = - maybeDataGenerator - .getOrElse(fail(s"Failed to create data generator for schema $schema")) - val data = (1 to 10).map { i => - dataGenerator.apply() match { - case row: Row => row - case null => Row.fromSeq(Seq.fill(schema.length)(null)) - case other => - fail(s"Row or null is expected to be generated, " + - s"but a ${other.getClass.getCanonicalName} is generated.") + seed = Some(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $dataType") } - } - // Create a DF for the schema with random data. 
- val rdd = sqlContext.sparkContext.parallelize(data, 10) - val df = sqlContext.createDataFrame(rdd, schema) + // Create a DF for the schema with random data. The index field is used to sort the + // DataFrame. This is a workaround for SPARK-10591. + val schema = new StructType() + .add("index", IntegerType, nullable = false) + .add("col", dataType, nullable = true) + val rdd = sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) + val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - // All columns that have supported data types of this source. - val supportedColumns = schema.fields.collect { - case StructField(name, dataType, _, _) if supportsDataType(dataType) => name - } - val selectedColumns = util.Random.shuffle(supportedColumns.toSeq) - - val dfToBeSaved = df.selectExpr(selectedColumns: _*) - - // Save the data out. - dfToBeSaved - .write - .format(dataSourceName) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .save(file.getCanonicalPath) + df.write + .mode("overwrite") + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .save(path) - val loadedDF = - sqlContext + val loadedDF = sqlContext .read .format(dataSourceName) - .schema(dfToBeSaved.schema) - .option("dataSchema", dfToBeSaved.schema.json) // This option is just used by tests. - .load(file.getCanonicalPath) - .selectExpr(selectedColumns: _*) + .option("dataSchema", df.schema.json) + .schema(df.schema) + .load(path) + .orderBy("index") - // Read the data back. - checkAnswer( - loadedDF, - dfToBeSaved - ) + checkAnswer(loadedDF, df) + } } } From c6f8135ee52202bd86adb090ab631e80330ea4df Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Fri, 18 Sep 2015 13:20:13 -0700 Subject: [PATCH 344/802] [SPARK-10539] [SQL] Project should not be pushed down through Intersect or Except #8742 Intersect and Except are both set operators and they use the all the columns to compare equality between rows. When pushing their Project parent down, the relations they based on would change, therefore not an equivalent transformation. JIRA: https://issues.apache.org/jira/browse/SPARK-10539 I added some comments based on the fix of https://github.com/apache/spark/pull/8742. Author: Yijie Shen Author: Yin Huai Closes #8823 from yhuai/fix_set_optimization. --- .../sql/catalyst/optimizer/Optimizer.scala | 37 ++++++++++--------- .../optimizer/SetOperationPushDownSuite.scala | 23 ++++++------ .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++ 3 files changed, 39 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 648a65e7c0eb3..324f40a051c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -85,7 +85,22 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union, Intersect or Except. + * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Operations that are safe to pushdown are listed as follows. + * Union: + * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is + * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, + * we will not be able to pushdown Projections. 
+ * + * Intersect: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. + * + * Except: + * It is not safe to pushdown Projections through it because we need to get the + * intersect of rows by comparing the entire rows. It is fine to pushdown Filters + * because we will not have non-deterministic expressions. */ object SetOperationPushDown extends Rule[LogicalPlan] { @@ -122,40 +137,26 @@ object SetOperationPushDown extends Rule[LogicalPlan] { Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into union + // Push down projection through UNION ALL case Project(projectList, u @ Union(left, right)) => val rewrites = buildRewrites(u) Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) - // Push down filter into intersect + // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => val rewrites = buildRewrites(i) Intersect( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - // Push down projection into intersect - case Project(projectList, i @ Intersect(left, right)) => - val rewrites = buildRewrites(i) - Intersect( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) - - // Push down filter into except + // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => val rewrites = buildRewrites(e) Except( Filter(condition, left), Filter(pushToRight(condition, rewrites), right)) - - // Push down projection into except - case Project(projectList, e @ Except(left, right)) => - val rewrites = buildRewrites(e) - Except( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 49c979bc7d72c..3fca47a023dc6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -60,23 +60,22 @@ class SetOperationPushDownSuite extends PlanTest { comparePlans(exceptOptimized, exceptCorrectAnswer) } - test("union/intersect/except: project to each side") { + test("union: project to each side") { val unionQuery = testUnion.select('a) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - val unionOptimized = Optimize.execute(unionQuery.analyze) val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - val unionCorrectAnswer = - Union(testRelation.select('a), testRelation2.select('d)).analyze - val intersectCorrectAnswer = - Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze - val exceptCorrectAnswer = - Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze - - 
comparePlans(unionOptimized, unionCorrectAnswer) - comparePlans(intersectOptimized, intersectCorrectAnswer) - comparePlans(exceptOptimized, exceptCorrectAnswer) } + comparePlans(intersectOptimized, intersectQuery.analyze) + comparePlans(exceptOptimized, exceptQuery.analyze) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c167999af580e..1370713975f2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -907,4 +907,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001) } } + + test("SPARK-10539: Project should not be pushed down through Intersect or Except") { + val df1 = (1 to 100).map(Tuple1.apply).toDF("i") + val df2 = (1 to 30).map(Tuple1.apply).toDF("i") + val intersect = df1.intersect(df2) + val except = df1.except(df2) + assert(intersect.count() === 30) + assert(except.count() === 70) + } } From 3a22b1004f527d54d399dd0225cd7f2f8ffad9c5 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Sep 2015 13:47:14 -0700 Subject: [PATCH 345/802] [SPARK-10449] [SQL] Don't merge decimal types with incompatable precision or scales From JIRA: Schema merging should only handle struct fields. But currently we also reconcile decimal precision and scale information. Author: Holden Karau Closes #8634 from holdenk/SPARK-10449-dont-merge-different-precision. --- .../org/apache/spark/sql/types/StructType.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b29cf22dcb582..d6b436724b2a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -373,10 +373,19 @@ object StructType extends AbstractDataType { StructType(newFields) case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) + DecimalType.Fixed(rightPrecision, rightScale)) => + if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { + DecimalType(leftPrecision, leftScale) + } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale") + } else if (leftPrecision != rightPrecision) { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"precision $leftPrecision and $rightPrecision") + } else { + throw new SparkException("Failed to merge Decimal Tpes with incompatible " + + s"scala $leftScale and $rightScale") + } case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) if leftUdt.userClass == rightUdt.userClass => leftUdt From 348d7c9a93dd00d3d1859342a8eb0aea2e77f597 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 18 Sep 2015 13:48:41 -0700 Subject: [PATCH 346/802] [SPARK-9808] Remove hash shuffle file consolidation. Author: Reynold Xin Closes #8812 from rxin/SPARK-9808-1. 
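[Editor's note] For context on what is being removed: hash-based shuffle writes one file per (map task, reducer) pair, and the consolidation mode deleted below existed to reuse one file group per concurrently executing task, with one file per reducer in each group. A rough back-of-the-envelope sketch, with assumed (not measured) figures:

{noformat}
// Illustrative arithmetic only; the task, reducer and core counts are assumptions.
object ShuffleFileCountSketch {
  def main(args: Array[String]): Unit = {
    val mapTasksPerExecutor = 125   // e.g. 1000 map tasks spread over 8 executors
    val reducers            = 200
    val concurrentTasks     = 16    // cores per executor

    // One file per (map task, reducer): the layout that remains after this patch.
    val perReducerFiles = mapTasksPerExecutor * reducers   // 25,000 files per executor

    // One file group per concurrently executing task, each holding one file per reducer:
    // the consolidation mode this patch removes.
    val consolidatedFiles = concurrentTasks * reducers     // 3,200 files per executor

    println(s"plain hash shuffle: $perReducerFiles files, consolidated: $consolidatedFiles files")
  }
}
{noformat}

For comparison, sort-based shuffle writes a single data file plus an index file per map task, which avoids the file-count problem altogether.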
--- .../shuffle/FileShuffleBlockResolver.scala | 178 ++---------------- .../apache/spark/storage/BlockManager.scala | 9 - .../org/apache/spark/storage/DiskStore.scala | 3 - .../hash/HashShuffleManagerSuite.scala | 110 ----------- docs/configuration.md | 10 - .../shuffle/ExternalShuffleBlockResolver.java | 4 - project/MimaExcludes.scala | 4 + 7 files changed, 17 insertions(+), 301 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index c057de9b3f4df..d9902f96dfd4e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,9 +17,7 @@ package org.apache.spark.shuffle -import java.io.File import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -28,10 +26,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FileShuffleBlockResolver.ShuffleFileGroup import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -43,24 +39,7 @@ private[spark] trait ShuffleWriterGroup { /** * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer (this set of files is called a ShuffleFileGroup). - * - * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle - * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer - * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle - * files, it releases them for another task. - * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: - * - shuffleId: The unique id given to the entire shuffle stage. - * - bucketId: The id of the output partition (i.e., reducer id) - * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a - * time owns a particular fileId, and this id is returned to a pool when the task finishes. - * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length) - * that specifies where in a given file the actual block data is located. - * - * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping - * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for - * each block stored in each file. In order to find the location of a shuffle block, we search the - * files within a ShuffleFileGroups associated with the block's reducer. + * per reducer. */ // Note: Changes to the format in this file should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). 
@@ -71,26 +50,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private lazy val blockManager = SparkEnv.get.blockManager - // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. - // TODO: Remove this once the shuffle file consolidation feature is stable. - private val consolidateShuffleFiles = - conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** - * Contains all the state related to a particular shuffle. This includes a pool of unused - * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle. + * Contains all the state related to a particular shuffle. */ - private class ShuffleState(val numBuckets: Int) { - val nextFileId = new AtomicInteger(0) - val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]() - + private class ShuffleState(val numReducers: Int) { /** * The mapIds of all map tasks completed on this Executor for this shuffle. - * NB: This is only populated if consolidateShuffleFiles is FALSE. We don't need it otherwise. */ val completedMapTasks = new ConcurrentLinkedQueue[Int]() } @@ -104,24 +72,16 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) private val shuffleState = shuffleStates(shuffleId) - private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { - fileGroup = getUnusedFileGroup() - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, - writeMetrics) - } - } else { - Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => + val writers: Array[DiskBlockObjectWriter] = { + Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. 
@@ -142,58 +102,14 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) writeMetrics.incShuffleWriteTime(System.nanoTime - openStartTime) override def releaseWriters(success: Boolean) { - if (consolidateShuffleFiles) { - if (success) { - val offsets = writers.map(_.fileSegment().offset) - val lengths = writers.map(_.fileSegment().length) - fileGroup.recordMapOutput(mapId, offsets, lengths) - } - recycleFileGroup(fileGroup) - } else { - shuffleState.completedMapTasks.add(mapId) - } - } - - private def getUnusedFileGroup(): ShuffleFileGroup = { - val fileGroup = shuffleState.unusedFileGroups.poll() - if (fileGroup != null) fileGroup else newFileGroup() - } - - private def newFileGroup(): ShuffleFileGroup = { - val fileId = shuffleState.nextFileId.getAndIncrement() - val files = Array.tabulate[File](numBuckets) { bucketId => - val filename = physicalFileName(shuffleId, bucketId, fileId) - blockManager.diskBlockManager.getFile(filename) - } - val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) - shuffleState.allFileGroups.add(fileGroup) - fileGroup - } - - private def recycleFileGroup(group: ShuffleFileGroup) { - shuffleState.unusedFileGroups.add(group) + shuffleState.completedMapTasks.add(mapId) } } } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - if (consolidateShuffleFiles) { - // Search all file groups associated with this shuffle. - val shuffleState = shuffleStates(blockId.shuffleId) - val iter = shuffleState.allFileGroups.iterator - while (iter.hasNext) { - val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) - if (segmentOpt.isDefined) { - val segment = segmentOpt.get - return new FileSegmentManagedBuffer( - transportConf, segment.file, segment.offset, segment.length) - } - } - throw new IllegalStateException("Failed to find shuffle block: " + blockId) - } else { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } + val file = blockManager.diskBlockManager.getFile(blockId) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } /** Remove all the blocks / files and metadata related to a particular shuffle. 
*/ @@ -209,17 +125,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { shuffleStates.get(shuffleId) match { case Some(state) => - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups.asScala; - file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks.asScala; - reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } + for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() } logInfo("Deleted all files for shuffle " + shuffleId) true @@ -229,10 +137,6 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { - "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) - } - private def cleanup(cleanupTime: Long) { shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } @@ -241,59 +145,3 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) metadataCleaner.cancel() } } - -private[spark] object FileShuffleBlockResolver { - /** - * A group of shuffle files, one per reducer. - * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. - */ - private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { - private var numBlocks: Int = 0 - - /** - * Stores the absolute index of each mapId in the files of this group. For instance, - * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. - */ - private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - - /** - * Stores consecutive offsets and lengths of blocks into each reducer file, ordered by - * position in the file. - * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every - * reducer. - */ - private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - private val blockLengthsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { - new PrimitiveVector[Long]() - } - - def apply(bucketId: Int): File = files(bucketId) - - def recordMapOutput(mapId: Int, offsets: Array[Long], lengths: Array[Long]) { - assert(offsets.length == lengths.length) - mapIdToIndex(mapId) = numBlocks - numBlocks += 1 - for (i <- 0 until offsets.length) { - blockOffsetsByReducer(i) += offsets(i) - blockLengthsByReducer(i) += lengths(i) - } - } - - /** Returns the FileSegment associated with the given map task, or None if no entry exists. 
*/ - def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { - val file = files(reducerId) - val blockOffsets = blockOffsetsByReducer(reducerId) - val blockLengths = blockLengthsByReducer(reducerId) - val index = mapIdToIndex.getOrElse(mapId, -1) - if (index >= 0) { - val offset = blockOffsets(index) - val length = blockLengths(index) - Some(new FileSegment(file, offset, length)) - } else { - None - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d31aa68eb6954..bca3942f8c555 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -106,15 +106,6 @@ private[spark] class BlockManager( } } - // Check that we're not using external shuffle service with consolidated shuffle files. - if (externalShuffleServiceEnabled - && conf.getBoolean("spark.shuffle.consolidateFiles", false) - && shuffleManager.isInstanceOf[HashShuffleManager]) { - throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" - + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " - + " switch to sort-based shuffle.") - } - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f45956282166..feb9533604ffb 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -154,9 +154,6 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc override def remove(blockId: BlockId): Boolean = { val file = diskManager.getFile(blockId.name) - // If consolidation mode is used With HashShuffleMananger, the physical filename for the block - // is different from blockId.name. So the file returns here will not be exist, thus we avoid to - // delete the whole consolidated file by mistake. if (file.exists()) { file.delete() } else { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala deleted file mode 100644 index 491dc3659e184..0000000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle.hash - -import java.io.{File, FileWriter} - -import scala.language.reflectiveCalls - -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.FileShuffleBlockResolver -import org.apache.spark.storage.{ShuffleBlockId, FileSegment} - -class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext { - private val testConf = new SparkConf(false) - - private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) - val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] - assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath) - assert(expected.offset === segment.getOffset) - assert(expected.length === segment.getLength) - } - - test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") { - - val conf = new SparkConf(false) - // reset after EACH object write. This is to ensure that there are bytes appended after - // an object is written. So if the codepaths assume writeObject is end of data, this should - // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc. - conf.set("spark.serializer.objectStreamReset", "1") - conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") - conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - sc = new SparkContext("local", "test", conf) - - val shuffleBlockResolver = - SparkEnv.get.shuffleManager.shuffleBlockResolver.asInstanceOf[FileShuffleBlockResolver] - - val shuffle1 = shuffleBlockResolver.forMapTask(1, 1, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - for (writer <- shuffle1.writers) { - writer.write("test1", "value") - writer.write("test2", "value") - } - for (writer <- shuffle1.writers) { - writer.commitAndClose() - } - - val shuffle1Segment = shuffle1.writers(0).fileSegment() - shuffle1.releaseWriters(success = true) - - val shuffle2 = shuffleBlockResolver.forMapTask(1, 2, 1, new JavaSerializer(conf), - new ShuffleWriteMetrics) - - for (writer <- shuffle2.writers) { - writer.write("test3", "value") - writer.write("test4", "vlue") - } - for (writer <- shuffle2.writers) { - writer.commitAndClose() - } - val shuffle2Segment = shuffle2.writers(0).fileSegment() - shuffle2.releaseWriters(success = true) - - // Now comes the test : - // Write to shuffle 3; and close it, but before registering it, check if the file lengths for - // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length - // of block based on remaining data in file : which could mess things up when there is - // concurrent read and writes happening to the same shuffle group. - - val shuffle3 = shuffleBlockResolver.forMapTask(1, 3, 1, new JavaSerializer(testConf), - new ShuffleWriteMetrics) - for (writer <- shuffle3.writers) { - writer.write("test3", "value") - writer.write("test4", "value") - } - for (writer <- shuffle3.writers) { - writer.commitAndClose() - } - // check before we register. 
- checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffle3.releaseWriters(success = true) - checkSegments(shuffle2Segment, shuffleBlockResolver.getBlockData(ShuffleBlockId(1, 2, 0))) - shuffleBlockResolver.removeShuffle(1) - } - - def writeToFile(file: File, numBytes: Int) { - val writer = new FileWriter(file, true) - for (i <- 0 until numBytes) writer.write(i) - writer.close() - } -} diff --git a/docs/configuration.md b/docs/configuration.md index 1a701f18881fe..3700051efb448 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -390,16 +390,6 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. - - spark.shuffle.consolidateFiles - false - - If set to "true", consolidates intermediate files created during a shuffle. Creating fewer - files can improve filesystem performance for shuffles with large numbers of reduce tasks. It - is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option - might degrade performance on machines with many (>8) cores due to filesystem limitations. - - spark.shuffle.file.buffer 32k diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 79beec4429a99..c5f93bb47f55c 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -50,9 +50,6 @@ * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. - * - * Executors with shuffle file consolidation are not currently supported, as the index is stored in - * the Executor's memory, unlike the IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -254,7 +251,6 @@ private void deleteExecutorDirs(String[] dirs) { * Hash-based shuffle data is simply stored as one file per block. * This logic is from FileShuffleBlockResolver. 
*/ - // TODO: Support consolidated hash shuffle files private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1c96b0958586f..814a11e588ceb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,6 +70,10 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") + ) ++ + Seq( + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") ) case v if v.startsWith("1.5") => Seq( From 8074208fa47fa654c1055c48cfa0d923edeeb04f Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Fri, 18 Sep 2015 15:40:58 -0700 Subject: [PATCH 347/802] [SPARK-10611] Clone Configuration for each task for NewHadoopRDD This patch attempts to fix the Hadoop Configuration thread safety issue for NewHadoopRDD in the same way SPARK-2546 fixed the issue for HadoopRDD. Author: Mingyu Kim Closes #8763 from mingyukim/mkim/SPARK-10611. --- .../org/apache/spark/rdd/BinaryFileRDD.scala | 5 ++- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++++++--- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 6fec00dcd0d85..aedced7408cde 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -34,12 +34,13 @@ private[spark] class BinaryFileRDD[T]( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 174979aaeb231..2872b93b8730e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -44,7 +44,6 @@ private[spark] class NewHadoopPartition( extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) - override def hashCode(): Int = 41 * (41 + rddId) + index } @@ -84,6 +83,27 @@ class NewHadoopRDD[K, V]( @transient protected val jobId = new JobID(jobTrackerId, id) + private val shouldCloneJobConf = sparkContext.conf.getBoolean("spark.hadoop.cloneConf", false) + + def getConf: Configuration = { + val conf: Configuration = confBroadcast.value.value + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546, SPARK-10611). This + // problem occurs somewhat rarely because most jobs treat the configuration as though it's + // immutable. 
One solution, implemented here, is to clone the Configuration object. + // Unfortunately, this clone can be very expensive. To avoid unexpected performance + // regressions for workloads and Hadoop versions that do not suffer from these thread-safety + // issues, this cloning is disabled by default. + NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") + new Configuration(conf) + } + } else { + conf + } + } + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance inputFormat match { @@ -104,7 +124,7 @@ class NewHadoopRDD[K, V]( val iter = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value + val conf = getConf val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) @@ -230,11 +250,15 @@ class NewHadoopRDD[K, V]( super.persist(storageLevel) } - - def getConf: Configuration = confBroadcast.value.value } private[spark] object NewHadoopRDD { + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new Configuration(). + */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. @@ -268,12 +292,13 @@ private[spark] class WholeTextFileRDD( override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance + val conf = getConf inputFormat match { case configurable: Configurable => - configurable.setConf(getConf) + configurable.setConf(conf) case _ => } - val jobContext = newJobContext(getConf, jobId) + val jobContext = newJobContext(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) From c8149ef2c57f5c47ab97ee8d8d58a216d4bc4156 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 18 Sep 2015 16:23:05 -0700 Subject: [PATCH 348/802] [MINOR] [ML] override toString of AttributeGroup This makes equality test failures much more readable. mengxr Author: Eric Liang Author: Eric Liang Closes #8826 from ericl/attrgroupstr. --- .../scala/org/apache/spark/ml/attribute/AttributeGroup.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 457c15830fd38..2c29eeb01a921 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -183,6 +183,8 @@ class AttributeGroup private ( sum = 37 * sum + attributes.map(_.toSeq).hashCode sum } + + override def toString: String = toMetadata.toString } /** From 22be2ae147a111e88896f6fb42ed46bbf108a99b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 18 Sep 2015 18:42:20 -0700 Subject: [PATCH 349/802] [SPARK-10623] [SQL] Fixes ORC predicate push-down When pushing down a leaf predicate, ORC `SearchArgument` builder requires an extra "parent" predicate (any one among `AND`/`OR`/`NOT`) to wrap the leaf predicate. E.g., to push down `a < 1`, we must build `AND(a < 1)` instead. 
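As an illustrative sketch only (not part of the patch; it assumes the Hive 1.2 `SearchArgumentFactory` builder API that `OrcFilters` already uses, with the column and literal from the example above):

```scala
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory

// Build AND(a < 1) rather than the bare leaf a < 1: the leaf is wrapped in a
// logically redundant AND so that the builder accepts it.
val builder = SearchArgumentFactory.newBuilder()
val sarg = builder.startAnd().lessThan("a", 1L).end().build()
```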
Fortunately, when actually constructing the `SearchArgument`, the builder will eliminate all those unnecessary wrappers. This PR is based on #8783 authored by zhzhan. I also took the chance to simply `OrcFilters` a little bit to improve readability. Author: Cheng Lian Closes #8799 from liancheng/spark-10623/fix-orc-ppd. --- .../spark/sql/hive/orc/OrcFilters.scala | 56 ++++++++----------- .../spark/sql/hive/orc/OrcQuerySuite.scala | 30 ++++++++++ 2 files changed, 52 insertions(+), 34 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index b3d9f7f71a27d..27193f54d3a91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.sources._ * and cannot be used anymore. */ private[orc] object OrcFilters extends Logging { - def createFilter(expr: Array[Filter]): Option[SearchArgument] = { - expr.reduceOption(And).flatMap { conjunction => - val builder = SearchArgumentFactory.newBuilder() - buildSearchArgument(conjunction, builder).map(_.build()) - } + def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + for { + // Combines all filters with `And`s to produce a single conjunction predicate + conjunction <- filters.reduceOption(And) + // Then tries to build a single ORC `SearchArgument` for the conjunction predicate + builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) + } yield builder.build() } private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { @@ -102,46 +104,32 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() - case EqualTo(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.equals(attribute, _)) + case EqualTo(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().equals(attribute, value).end()) - case EqualNullSafe(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.nullSafeEquals(attribute, _)) + case EqualNullSafe(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().nullSafeEquals(attribute, value).end()) - case LessThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThan(attribute, _)) + case LessThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThan(attribute, value).end()) - case LessThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.lessThanEquals(attribute, _)) + case LessThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startAnd().lessThanEquals(attribute, value).end()) - case GreaterThan(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThanEquals(attribute, _).end()) + case GreaterThan(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThanEquals(attribute, value).end()) - case GreaterThanOrEqual(attribute, value) => - Option(value) - .filter(isSearchableLiteral) - .map(builder.startNot().lessThan(attribute, _).end()) + case GreaterThanOrEqual(attribute, value) if isSearchableLiteral(value) => + Some(builder.startNot().lessThan(attribute, value).end()) case IsNull(attribute) => - Some(builder.isNull(attribute)) + 
Some(builder.startAnd().isNull(attribute).end()) case IsNotNull(attribute) => Some(builder.startNot().isNull(attribute).end()) - case In(attribute, values) => - Option(values) - .filter(_.forall(isSearchableLiteral)) - .map(builder.in(attribute, _)) + case In(attribute, values) if values.forall(isSearchableLiteral) => + Some(builder.startAnd().in(attribute, values.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 8bc33fcf5d906..5eb39b1129701 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -344,4 +344,34 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-10623 Enable ORC PPD") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + import testImplicits._ + + val path = dir.getCanonicalPath + sqlContext.range(10).coalesce(1).write.orc(path) + val df = sqlContext.read.orc(path) + + def checkPredicate(pred: Column, answer: Seq[Long]): Unit = { + checkAnswer(df.where(pred), answer.map(Row(_))) + } + + checkPredicate('id === 5, Seq(5L)) + checkPredicate('id <=> 5, Seq(5L)) + checkPredicate('id < 5, 0L to 4L) + checkPredicate('id <= 5, 0L to 5L) + checkPredicate('id > 5, 6L to 9L) + checkPredicate('id >= 5, 5L to 9L) + checkPredicate('id.isNull, Seq.empty[Long]) + checkPredicate('id.isNotNull, 0L to 9L) + checkPredicate('id.isin(1L, 3L, 5L), Seq(1L, 3L, 5L)) + checkPredicate('id > 0 && 'id < 3, 1L to 2L) + checkPredicate('id < 1 || 'id > 8, Seq(0L, 9L)) + checkPredicate(!('id > 3), 0L to 3L) + checkPredicate(!('id > 0 && 'id < 3), Seq(0L) ++ (3L to 9L)) + } + } + } } From 7ff8d68cc19299e16dedfd819b9e96480fa6cf44 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 18 Sep 2015 23:58:25 -0700 Subject: [PATCH 350/802] [SPARK-10474] [SQL] Aggregation fails to allocate memory for pointer array When `TungstenAggregation` hits memory pressure, it switches from hash-based to sort-based aggregation in-place. However, in the process we try to allocate the pointer array for writing to the new `UnsafeExternalSorter` *before* actually freeing the memory from the hash map. This lead to the following exception: ``` java.io.IOException: Could not acquire 65536 bytes of memory at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.initializeForWriting(UnsafeExternalSorter.java:169) at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.spill(UnsafeExternalSorter.java:220) at org.apache.spark.sql.execution.UnsafeKVExternalSorter.(UnsafeKVExternalSorter.java:126) at org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter(UnsafeFixedWidthAggregationMap.java:257) at org.apache.spark.sql.execution.aggregate.TungstenAggregationIterator.switchToSortBasedAggregation(TungstenAggregationIterator.scala:435) ``` Author: Andrew Or Closes #8827 from andrewor14/allocate-pointer-array. 
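In sketch form, the fix reorders the teardown so the pointer array is only allocated after the map's memory has been released (Scala-style pseudocode of the change to the Java constructor shown in the diff below; `sorter` and `map` are the names used there):

```scala
// Spill the in-memory records, but do NOT allocate the next pointer array yet.
sorter.spill(false /* initialize for writing */)
// Release the hash map's pages; the spill alone does not return this memory.
map.free()
// Now there is room to allocate the pointer array for subsequent writes.
sorter.initializeForWriting()
```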
--- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++- .../sql/execution/UnsafeKVExternalSorter.java | 8 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 49 ++++++++++++++++++- 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index fc364e0a895b1..14b6aafdea7df 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -159,7 +159,7 @@ public BoxedUnit apply() { /** * Allocates new sort data structures. Called when creating the sorter and after each spill. */ - private void initializeForWriting() throws IOException { + public void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); final long pointerArrayMemory = UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); @@ -187,6 +187,14 @@ public void closeCurrentPage() { * Sort and spill the current records in response to memory pressure. */ public void spill() throws IOException { + spill(true); + } + + /** + * Sort and spill the current records in response to memory pressure. + * @param shouldInitializeForWriting whether to allocate memory for writing after the spill + */ + public void spill(boolean shouldInitializeForWriting) throws IOException { assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -217,7 +225,9 @@ public void spill() throws IOException { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + if (shouldInitializeForWriting) { + initializeForWriting(); + } } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7db6b7ff50f22..b81f67a16b815 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,6 +85,7 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, // We will use the number of elements in the map as the initialSize of the // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize, // we will use 1 as its initial size if the map is empty. + // TODO: track pointer array memory used by this in-memory sorter! final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); @@ -123,8 +124,13 @@ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, pageSizeBytes, inMemSorter); - sorter.spill(); + // Note: This spill doesn't actually release any memory, so if we try to allocate a new + // pointer array immediately after the spill then we may fail to acquire sufficient space + // for it (SPARK-10474). For this reason, we must initialize for writing explicitly *after* + // we have actually freed memory from our map. 
+ sorter.spill(false /* initialize for writing */); map.free(); + sorter.initializeForWriting(); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index d1f0b2b1fc52f..ada4d42f991ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,9 +23,10 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.shuffle.ShuffleMemoryManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} @@ -325,7 +326,7 @@ class UnsafeFixedWidthAggregationMapSuite // At here, we also test if copy is correct. iter.getKey.copy() iter.getValue.copy() - count += 1; + count += 1 } // 1 record was from the map and 4096 records were explicitly inserted. @@ -333,4 +334,48 @@ class UnsafeFixedWidthAggregationMapSuite map.free() } + + testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { + val smm = ShuffleMemoryManager.createForTesting(65536) + val pageSize = 4096 + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + taskMemoryManager, + smm, + 128, // initial capacity + pageSize, + false // disable perf metrics + ) + + // Insert into the map until we've run out of space + val rand = new Random(42) + var hasSpace = true + while (hasSpace) { + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + if (buf == null) { + hasSpace = false + } else { + buf.setInt(0, str.length) + } + } + + // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte + assert(smm.tryToAcquire(1) === 0) + + // Convert the map into a sorter. This used to fail before the fix for SPARK-10474 + // because we would try to acquire space for the in-memory sorter pointer array before + // actually releasing the pages despite having spilled all of them. + var sorter: UnsafeKVExternalSorter = null + try { + sorter = map.destructAndCreateExternalSorter() + } finally { + if (sorter != null) { + sorter.cleanupResources() + } + } + } + } From d507f9c0b7f7a524137a694ed6443747aaf90463 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 19 Sep 2015 01:59:36 -0700 Subject: [PATCH 351/802] [SPARK-10584] [SQL] [DOC] Documentation about the compatible Hive version is wrong. In Spark 1.5.0, Spark SQL is compatible with Hive 0.12.0 through 1.2.1 but the documentation is wrong. /CC yhuai Author: Kousuke Saruta Closes #8776 from sarutak/SPARK-10584-2. --- docs/sql-programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a0b911d207243..82d4243cc6b27 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1954,7 +1954,7 @@ without the need to write any code. 
## Running the Thrift JDBC/ODBC server The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) -in Hive 0.13. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.13. +in Hive 1.2.1 You can test the JDBC server with the beeline script that comes with either Spark or Hive 1.2.1. To start the JDBC/ODBC server, run the following in the Spark directory: @@ -2260,8 +2260,10 @@ Several caching related features are not supported yet: ## Compatibility with Apache Hive -Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark -SQL is based on Hive 0.12.0 and 0.13.1. +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. +Currently Hive SerDes and UDFs are based on Hive 1.2.1, +and Spark SQL can be connected to different versions of Hive Metastore +(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). #### Deploying in Existing Hive Warehouses From d83b6aae8b4357c56779cc98804eb350ab8af62d Mon Sep 17 00:00:00 2001 From: Alexis Seigneurin Date: Sat, 19 Sep 2015 12:01:22 +0100 Subject: [PATCH 352/802] Fixed links to the API Submitting this change on the master branch as requested in https://github.com/apache/spark/pull/8819#issuecomment-141505941 Author: Alexis Seigneurin Closes #8838 from aseigneurin/patch-2. --- docs/ml-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index c5d7f990021f1..0427ac6695aa1 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -619,13 +619,13 @@ for row in selected.collect(): An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. `Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.Evaluator). +Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). `CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. `CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. 
-The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.MultiClassClassificationEvaluator) +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` method in each of these evaluators. From e789000b88a6bd840f821c53f42c08b97dc02496 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 18:22:43 -0700 Subject: [PATCH 353/802] [SPARK-10155] [SQL] Change SqlParser to object to avoid memory leak Since `scala.util.parsing.combinator.Parsers` is thread-safe since Scala 2.10 (See [SI-4929](https://issues.scala-lang.org/browse/SI-4929)), we can change SqlParser to object to avoid memory leak. I didn't change other subclasses of `scala.util.parsing.combinator.Parsers` because there is only one instance in one SQLContext, which should not be an issue. Author: zsxwing Closes #8357 from zsxwing/sql-memory-leak. --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- .../main/scala/org/apache/spark/sql/DataFrameWriter.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 6 +++--- .../src/main/scala/org/apache/spark/sql/functions.scala | 2 +- .../main/scala/org/apache/spark/sql/hive/HiveContext.scala | 6 +++--- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 9 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 5898a5f93f381..2bac08eac4fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def parse(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = synchronized { // Initialize the Keywords. 
initLexical phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 554fb4eb25eb1..e21d3c05464b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -61,7 +61,7 @@ abstract class ParserDialect { */ private[spark] class DefaultParserDialect extends ParserDialect { @transient - protected val sqlParser = new SqlParser + protected val sqlParser = SqlParser override def parse(sqlText: String): LogicalPlan = { sqlParser.parse(sqlText) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index f2498861c9573..dfab2398857e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -37,9 +37,9 @@ import org.apache.spark.unsafe.types.CalendarInterval * This is currently included mostly for illustrative purposes. Users wanting more complete support * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ -class SqlParser extends AbstractSparkSQLParser with DataTypeParser { +object SqlParser extends AbstractSparkSQLParser with DataTypeParser { - def parseExpression(input: String): Expression = { + def parseExpression(input: String): Expression = synchronized { // Initialize the Keywords. initLexical phrase(projection)(new lexical.Scanner(input)) match { @@ -48,7 +48,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } - def parseTableIdentifier(input: String): TableIdentifier = { + def parseTableIdentifier(input: String): TableIdentifier = synchronized { // Initialize the Keywords. 
initLexical phrase(tableIdentifier)(new lexical.Scanner(input)) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3e61123c145cd..8f737c2023931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -720,7 +720,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def selectExpr(exprs: String*): DataFrame = { select(exprs.map { expr => - Column(new SqlParser().parseExpression(expr)) + Column(SqlParser.parseExpression(expr)) }: _*) } @@ -745,7 +745,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def filter(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** @@ -769,7 +769,7 @@ class DataFrame private[sql]( * @since 1.5.0 */ def where(conditionExpr: String): DataFrame = { - filter(Column(new SqlParser().parseExpression(conditionExpr))) + filter(Column(SqlParser.parseExpression(conditionExpr))) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 745bb4ec9cf1c..03e973666e888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -163,7 +163,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - insertInto(new SqlParser().parseTableIdentifier(tableName)) + insertInto(SqlParser.parseTableIdentifier(tableName)) } private def insertInto(tableIdent: TableIdentifier): Unit = { @@ -197,7 +197,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + saveAsTable(SqlParser.parseTableIdentifier(tableName)) } private def saveAsTable(tableIdent: TableIdentifier): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e3fdd782e6ff6..f099940800cc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -590,7 +590,7 @@ class SQLContext(@transient val sparkContext: SparkContext) tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -636,7 +636,7 @@ class SQLContext(@transient val sparkContext: SparkContext) source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val cmd = CreateTableUsing( tableIdent, @@ -732,7 +732,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(new SqlParser().parseTableIdentifier(tableName)) + table(SqlParser.parseTableIdentifier(tableName)) } private def table(tableIdent: TableIdentifier): DataFrame = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 
60d9c509104d5..2467b4e48415b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -823,7 +823,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + def expr(expr: String): Column = Column(SqlParser.parseExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d37ba5ddc2d80..c12a734863326 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -291,12 +291,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) catalog.invalidateTable(tableIdent) } @@ -311,7 +311,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { */ @Experimental def analyze(tableName: String) { - val tableIdent = new SqlParser().parseTableIdentifier(tableName) + val tableIdent = SqlParser.parseTableIdentifier(tableName) val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent.toSeq)) relation match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a5569b0a4446..0c1b41e3377e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive options: Map[String, String], isExternal: Boolean): Unit = { createDataSourceTable( - new SqlParser().parseTableIdentifier(tableName), + SqlParser.parseTableIdentifier(tableName), userSpecifiedSchema, partitionColumns, provider, @@ -375,7 +375,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } def hiveDefaultTableFilePath(tableName: String): String = { - hiveDefaultTableFilePath(new SqlParser().parseTableIdentifier(tableName)) + hiveDefaultTableFilePath(SqlParser.parseTableIdentifier(tableName)) } def hiveDefaultTableFilePath(tableIdent: TableIdentifier): String = { From 2117eea71ece825fbc3797c8b38184ae221f5223 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 19 Sep 2015 21:40:21 -0700 Subject: [PATCH 354/802] [SPARK-10710] Remove ability to disable spilling in core and SQL It does not make much sense to set `spark.shuffle.spill` or `spark.sql.planner.externalSort` to false: I believe that these configurations were initially added as "escape hatches" to guard against bugs in the external operators, but these operators are now mature and well-tested. In addition, these configurations are not handled in a consistent way anymore: SQL's Tungsten codepath ignores these configurations and will continue to use spilling operators. 
Similarly, Spark Core's `tungsten-sort` shuffle manager does not respect `spark.shuffle.spill=false`. This pull request removes these configurations, adds warnings at the appropriate places, and deletes a large amount of code which was only used in code paths that did not support spilling. Author: Josh Rosen Closes #8831 from JoshRosen/remove-ability-to-disable-spilling. --- .../scala/org/apache/spark/Aggregator.scala | 59 +++++-------------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 40 ++++--------- .../shuffle/hash/HashShuffleManager.scala | 8 ++- .../shuffle/sort/SortShuffleManager.scala | 10 +++- .../util/collection/ExternalSorter.scala | 6 -- .../spark/deploy/SparkSubmitSuite.scala | 22 +++---- docs/configuration.md | 14 +---- docs/sql-programming-guide.md | 7 --- python/pyspark/rdd.py | 25 +++----- python/pyspark/shuffle.py | 30 ---------- python/pyspark/tests.py | 13 +--- .../scala/org/apache/spark/sql/SQLConf.scala | 8 +-- .../spark/sql/execution/SparkStrategies.scala | 2 - .../apache/spark/sql/execution/commands.scala | 9 +++ .../org/apache/spark/sql/execution/sort.scala | 30 +--------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 ++------ .../execution/RowFormatConvertersSuite.scala | 2 +- .../spark/sql/execution/SortSuite.scala | 4 +- 18 files changed, 81 insertions(+), 234 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 289aab9bd9e51..7196e57d5d2e2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} +import org.apache.spark.util.collection.ExternalAppendOnlyMap /** * :: DeveloperApi :: @@ -34,59 +34,30 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. 
- private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = combineValuesByKey(iter, null) - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], - context: TaskContext): Iterator[(K, C)] = { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineValuesByKey( + iter: Iterator[_ <: Product2[K, V]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = combineCombinersByKey(iter, null) - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext) - : Iterator[(K, C)] = - { - if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K, C] - var kc: Product2[K, C] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 - } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator - } else { - val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) - combiners.insertAll(iter) - updateMetrics(context, combiners) - combiners.iterator - } + def combineCombinersByKey( + iter: Iterator[_ <: Product2[K, C]], + context: TaskContext): Iterator[(K, C)] = { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + combiners.insertAll(iter) + updateMetrics(context, combiners) + combiners.iterator } /** Update task metrics after populating the external map. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7bad749d58327..935c3babd8ea1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -26,7 +26,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer @@ -128,8 +128,6 @@ class CoGroupedRDD[K: ClassTag]( override val partitioner: Some[Partitioner] = Some(part) override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = { - val sparkConf = SparkEnv.get.conf - val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = dependencies.length @@ -150,34 +148,16 @@ class CoGroupedRDD[K: ClassTag]( rddIterators += ((it, depNum)) } - if (!externalSorting) { - val map = new AppendOnlyMap[K, CoGroupCombiner] - val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) - } - val getCombiner: K => CoGroupCombiner = key => { - map.changeValue(key, update) - } - rddIterators.foreach { case (it, depNum) => - while (it.hasNext) { - val kv = it.next() - getCombiner(kv._1)(depNum) += kv._2 - } - } - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) - } else { - val map = createExternalMap(numRdds) - for ((it, depNum) <- rddIterators) { - map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) - } - context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) - context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) - context.internalMetricsToAccumulators( - InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) - new InterruptibleIterator(context, - map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) + val map = createExternalMap(numRdds) + for ((it, depNum) <- rddIterators) { + map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum)))) } + context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled) + context.internalMetricsToAccumulators( + InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes) + new InterruptibleIterator(context, + map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]]) } private def createExternalMap(numRdds: Int) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index c089088f409dd..0b46634b8b466 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -24,7 +24,13 @@ import org.apache.spark.shuffle._ * A ShuffleManager using hashing, that creates one output file per reduce partition on each * mapper (possibly reusing these across waves of tasks). 
*/ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d7fab351ca3b8..476cc1f303da7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,11 +19,17 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap -import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency} +import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.hash.HashShuffleReader -private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { +private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + + " Shuffle will continue to spill to disk when necessary.") + } private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 31230d5978b2a..2a30f751ff03d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -116,8 +116,6 @@ private[spark] class ExternalSorter[K, V, C]( private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() - private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 @@ -229,10 +227,6 @@ private[spark] class ExternalSorter[K, V, C]( * @param usingMap whether we're using a map or buffer as our current in-memory collection */ private def maybeSpillCollection(usingMap: Boolean): Unit = { - if (!spillingEnabled) { - return - } - var estimatedSize = 0L if (usingMap) { estimatedSize = map.estimateSize() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 1110ca6051a40..1fd470cd3b01d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -147,7 +147,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "beauty", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -166,7 +166,7 @@ class SparkSubmitSuite 
mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) sysProps("spark.app.name") should be ("beauty") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") sysProps.keys should not contain ("spark.jars") } @@ -185,7 +185,7 @@ class SparkSubmitSuite "--archives", "archive1.txt,archive2.txt", "--num-executors", "6", "--name", "trill", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -206,7 +206,7 @@ class SparkSubmitSuite sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone cluster mode") { @@ -229,7 +229,7 @@ class SparkSubmitSuite "--supervise", "--driver-memory", "4g", "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -253,9 +253,9 @@ class SparkSubmitSuite sysProps.keys should contain ("spark.driver.memory") sysProps.keys should contain ("spark.driver.cores") sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain ("spark.shuffle.spill") + sysProps.keys should contain ("spark.ui.enabled") sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -266,7 +266,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -277,7 +277,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -288,7 +288,7 @@ class SparkSubmitSuite "--total-executor-cores", "5", "--class", "org.SomeClass", "--driver-memory", "4g", - "--conf", "spark.shuffle.spill=false", + "--conf", "spark.ui.enabled=false", "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) @@ -299,7 +299,7 @@ class SparkSubmitSuite classpath(0) should endWith ("thejar.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.cores.max") should be ("5") - sysProps("spark.shuffle.spill") should be ("false") + sysProps("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { diff --git a/docs/configuration.md b/docs/configuration.md index 3700051efb448..5ec097c78aa38 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -69,7 +69,7 @@ val sc = new SparkContext(new SparkConf()) Then, you can supply configuration values at runtime: {% highlight bash %} -./bin/spark-submit --name "My app" --master local[4] --conf spark.shuffle.spill=false +./bin/spark-submit --name "My app" --master local[4] --conf spark.eventLog.enabled=false --conf 
"spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} @@ -449,8 +449,8 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.memoryFraction 0.2 - Fraction of Java heap to use for aggregation and cogroups during shuffles, if - spark.shuffle.spill is true. At any given time, the collective size of + Fraction of Java heap to use for aggregation and cogroups during shuffles. + At any given time, the collective size of all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will begin to spill to disk. If spills are often, consider increasing this value at the expense of spark.storage.memoryFraction. @@ -483,14 +483,6 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. - - spark.shuffle.spill - true - - If set to "true", limits the amount of memory used during reduces by spilling data out to disk. - This spilling threshold is specified by spark.shuffle.memoryFraction. - - spark.shuffle.spill.compress true diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 82d4243cc6b27..7ae9244c271e3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1936,13 +1936,6 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - - spark.sql.planner.externalSort - true - - When true, performs sorts spilling to disk as needed otherwise sort each partition in memory. - - # Distributed SQL Engine diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ab5aab1e115f7..73d7d9a5692a9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,7 +48,7 @@ from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ +from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync @@ -580,12 +580,11 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -610,12 +609,11 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer def sortPartition(iterator): - sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + sort = ExternalSorter(memory * 0.9, serializer).sorted return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: @@ -1770,13 
+1768,11 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, numPartitions = self._defaultReducePartitions() serializer = self.ctx.serializer - spill = self._can_spill() memory = self._memory_limit() agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1784,8 +1780,7 @@ def combineLocally(iterator): shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): - merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory, serializer) merger.mergeCombiners(iterator) return merger.items() @@ -1824,9 +1819,6 @@ def createZero(): return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - def _can_spill(self): - return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1857,14 +1849,12 @@ def mergeCombiners(a, b): a.extend(b) return a - spill = self._can_spill() memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combine(iterator): - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + merger = ExternalMerger(agg, memory * 0.9, serializer) merger.mergeValues(iterator) return merger.items() @@ -1872,8 +1862,7 @@ def combine(iterator): shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - merger = ExternalGroupBy(agg, memory, serializer)\ - if spill else InMemoryMerger(agg) + merger = ExternalGroupBy(agg, memory, serializer) merger.mergeCombiners(it) return merger.items() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b8118bdb7ca76..e974cda9fc3e1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -131,36 +131,6 @@ def items(self): raise NotImplementedError -class InMemoryMerger(Merger): - - """ - In memory merger based on in-memory dict. 
- """ - - def __init__(self, aggregator): - Merger.__init__(self, aggregator) - self.data = {} - - def mergeValues(self, iterator): - """ Combine the items by creator and combiner """ - # speed up attributes lookup - d, creator = self.data, self.agg.createCombiner - comb = self.agg.mergeValue - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else creator(v) - - def mergeCombiners(self, iterator): - """ Merge the combined items by mergeCombiner """ - # speed up attributes lookup - d, comb = self.data, self.agg.mergeCombiners - for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - - def items(self): - """ Return the merged items ad iterator """ - return iter(self.data.items()) - - def _compressed_serializer(self, serializer=None): # always use PickleSerializer to simplify implementation ser = PickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 647504c32f156..f11aaf001c8df 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -62,7 +62,7 @@ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ FlattenedValuesSerializer -from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -95,17 +95,6 @@ def setUp(self): lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) - def test_in_memory(self): - m = InMemoryMerger(self.agg) - m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - - m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.items()), - sum(xrange(self.N))) - def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9de75f4c4d084..b9fb90d964206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -330,11 +330,6 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. 
- val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", - defaultValue = Some(true), - doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + - " memory.") - val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") @@ -422,6 +417,7 @@ private[spark] object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" + val EXTERNAL_SORT = "spark.sql.planner.externalSort" } } @@ -476,8 +472,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, getConf(TUNGSTEN_ENABLED)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5e40d77689045..41b215c79296a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && TungstenSort.supportsSchema(child.schema)) { execution.TungstenSort(sortExprs, global, child) - } else if (sqlContext.conf.externalSortEnabled) { - execution.ExternalSort(sortExprs, global, child) } else { execution.Sort(sortExprs, global, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 95209e6634519..af28e2dfa4186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -105,6 +105,15 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Deprecated.EXTERNAL_SORT, Some(value))) => + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.EXTERNAL_SORT} is deprecated and will be ignored. " + + s"External sort will continue to be used.") + Seq(Row(SQLConf.Deprecated.EXTERNAL_SORT, "true")) + } + (keyValueOutput, runFunc) + // Configures a single property. case Some((key, Some(value))) => val runFunc = (sqlContext: SQLContext) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 40ef7c3b53530..27f26245a5ef0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -31,38 +31,12 @@ import org.apache.spark.{SparkEnv, InternalAccumulator, TaskContext} // This file defines various sort operators. //////////////////////////////////////////////////////////////////////////////////////////////////// - -/** - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. 
- */ -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - /** * Performs a sort, spilling to disk as needed. * @param global when true performs a global sort of all partitions by shuffling the data first * if necessary. */ -case class ExternalSort( +case class Sort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan) @@ -93,7 +67,7 @@ case class ExternalSort( } /** - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Optimized version of [[Sort]] that operates on binary data (implemented as part of * Project Tungsten). * * @param global when true performs a global sort of all partitions by shuffling the data first diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f9981356f364f..05b4127cbcaff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -581,28 +581,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { mapData.collect().sortBy(_.data(1)).reverse.map(Row.fromTuple).toSeq) } - test("sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false") { - sortTest() - } - } - test("external sorting") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true") { - sortTest() - } - } - - test("SPARK-6927 sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "true") { - sortTest() - } + sortTest() } test("SPARK-6927 external sorting with codegen on") { - withSQLConf(SQLConf.EXTERNAL_SORT.key -> "true", - SQLConf.CODEGEN_ENABLED.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { sortTest() } } @@ -1731,10 +1715,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("external sorting updates peak execution memory") { - withSQLConf((SQLConf.EXTERNAL_SORT.key, "true")) { - AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { - sortTest() - } + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sortTest() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 4492e37ad01ff..5dc37e5c3c238 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -32,7 +32,7 @@ class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { case c: ConvertToSafe => c } - private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) + private val outputsSafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) assert(!outputsSafe.outputsUnsafeRows) private val outputsUnsafe = TungstenSort(Nil, 
false, PhysicalRDD(Seq.empty, null, "name")) assert(outputsUnsafe.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 3073d492e613b..847c188a30333 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -36,13 +36,13 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), + Sort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), + Sort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } From 1aa9e50256988533fa54584b49dbc408a14438ee Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 20 Sep 2015 16:05:12 -0700 Subject: [PATCH 355/802] [SPARK-5905] [MLLIB] Note requirements for certain RowMatrix methods in docs Note methods that fail for cols > 65535; note that SVD does not require n >= m CC mengxr Author: Sean Owen Closes #8839 from srowen/SPARK-5905. --- .../spark/mllib/linalg/distributed/RowMatrix.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index e55ef26858adb..7c7d900af3d5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -109,7 +109,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. + * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with + * more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -150,7 +151,8 @@ class RowMatrix @Since("1.0.0") ( * - s is a Vector of size k, holding the singular values in descending order, * - V is a Matrix of size n x k that satisfies V' * V = eye(k). * - * We assume n is smaller than m. The singular values and the right singular vectors are derived + * We assume n is smaller than m, though this is not strictly required. + * The singular values and the right singular vectors are derived * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined @@ -320,7 +322,8 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. + * Computes the covariance matrix, treating each row as an observation. Note that this cannot + * be computed on matrices with more than 65535 columns. * @return a local dense matrix of size n x n */ @Since("1.0.0") @@ -374,6 +377,8 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * + * Note that this cannot be computed on matrices with more than 65535 columns. + * * @param k number of top principal components. 
* @return a matrix of size n-by-k, whose columns are principal components */ From 0c498717ba9622b6c889e701e8eed5ef9215c030 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Sun, 20 Sep 2015 16:16:31 -0700 Subject: [PATCH 356/802] [SPARK-10715] [ML] Duplicate initialization flag in WeightedLeastSquare There are duplicate set of initialization flag in `WeightedLeastSquares#add`. `initialized` is already set in `init(Int)`. Author: lewuathe Closes #8837 from Lewuathe/duplicate-initialization-flag. --- .../scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 0ff8931b0bab4..4374e99631560 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -193,7 +193,6 @@ private[ml] object WeightedLeastSquares { val ak = a.size if (!initialized) { init(ak) - initialized = true } assert(ak == k, s"Dimension mismatch. Expect vectors of size $k but got $ak.") count += 1L From 01440395176bdbb2662480f03b27851cb860f385 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Sun, 20 Sep 2015 22:55:24 -0700 Subject: [PATCH 357/802] [SPARK-10631] [DOCUMENTATION, MLLIB, PYSPARK] Added documentation for few APIs There are some missing API docs in pyspark.mllib.linalg.Vector (including DenseVector and SparseVector). We should add them based on their Scala counterparts. Author: vinodkc Closes #8834 from vinodkc/fix_SPARK-10631. --- python/pyspark/mllib/linalg/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 4829acb16ed8a..f929e3e96fbe2 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -301,11 +301,14 @@ def __reduce__(self): return DenseVector, (self.array.tostring(),) def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ return np.count_nonzero(self.array) def norm(self, p): """ - Calculte the norm of a DenseVector. + Calculates the norm of a DenseVector. >>> a = DenseVector([0, -1, 2, -3]) >>> a.norm(2) @@ -397,10 +400,16 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): + """ + Returns an numpy.ndarray + """ return self.array @property def values(self): + """ + Returns a list of values + """ return self.array def __getitem__(self, item): @@ -479,8 +488,8 @@ def __init__(self, size, *args): :param size: Size of the vector. :param args: Active entries, as a dictionary {index: value, ...}, - a list of tuples [(index, value), ...], or a list of strictly i - ncreasing indices and a list of corresponding values [index, ...], + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], [value, ...]. Inactive entries are treated as zeros. >>> SparseVector(4, {1: 1.0, 3: 5.5}) @@ -521,11 +530,14 @@ def __init__(self, size, *args): raise TypeError("indices array must be sorted") def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ return np.count_nonzero(self.values) def norm(self, p): """ - Calculte the norm of a SparseVector. + Calculates the norm of a SparseVector. 
>>> a = SparseVector(4, [0, 1], [3., -4.]) >>> a.norm(1) @@ -797,7 +809,7 @@ def sparse(size, *args): values (sorted by index). :param size: Size of the vector. - :param args: Non-zero entries, as a dictionary, list of tupes, + :param args: Non-zero entries, as a dictionary, list of tuples, or two sorted lists containing indices and values. >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) From 20a61dbd9b57957fcc5b58ef8935533914172b07 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 18:53:28 +0100 Subject: [PATCH 358/802] [SPARK-10626] [MLLIB] create java friendly method for random rdd SPARK-3136 added a large number of functions for creating Java RandomRDDs, but for people that want to use custom RandomDataGenerators we should make a Java friendly method. Author: Holden Karau Closes #8782 from holdenk/SPARK-10626-create-java-friendly-method-for-randomRDD. --- .../spark/mllib/random/RandomRDDs.scala | 52 ++++++++++++++++++- .../mllib/random/JavaRandomRDDsSuite.java | 30 +++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 4dd5ea214d678..f8ff26b5795be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} import org.apache.spark.rdd.RDD @@ -381,7 +382,7 @@ object RandomRDDs { * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). - * @return RDD[Double] comprised of `i.i.d.` samples produced by generator. + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. */ @DeveloperApi @Since("1.1.0") @@ -394,6 +395,55 @@ object RandomRDDs { new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * :: DeveloperApi :: + * Generates an RDD comprised of `i.i.d.` samples produced by the input RandomDataGenerator. + * + * @param jsc JavaSparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[T] comprised of `i.i.d.` samples produced by generator. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int, + seed: Long): JavaRDD[T] = { + implicit val ctag: ClassTag[T] = fakeClassTag + val rdd = randomRDD(jsc.sc, generator, size, numPartitions, seed) + JavaRDD.fromRDD(rdd) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed. 
+ */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, numPartitions, Utils.random.nextLong()) + } + + /** + * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaRDD[T]( + jsc: JavaSparkContext, + generator: RandomDataGenerator[T], + size: Long): JavaRDD[T] = { + randomJavaRDD(jsc, generator, size, 0); + } + // TODO Generate RDD[Vector] from multivariate distributions. /** diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index 33d81b1e9592b..fce5f6712f462 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.random; +import java.io.Serializable; import java.util.Arrays; import org.apache.spark.api.java.JavaRDD; @@ -231,4 +232,33 @@ public void testGammaVectorRDD() { } } + @Test + public void testArbitrary() { + long size = 10; + long seed = 1L; + int numPartitions = 0; + StringGenerator gen = new StringGenerator(); + JavaRDD rdd1 = randomJavaRDD(sc, gen, size); + JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(size, rdd.count()); + Assert.assertEquals(2, rdd.first().length()); + } + } +} + +// This is just a test generator, it always returns a string of 42 +class StringGenerator implements RandomDataGenerator, Serializable { + @Override + public String nextValue() { + return "42"; + } + @Override + public StringGenerator copy() { + return new StringGenerator(); + } + @Override + public void setSeed(long seed) { + } } From ebbf85f07bb8de0d566f1ae4b41f26421180bebe Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 21 Sep 2015 11:39:04 -0700 Subject: [PATCH 359/802] [SPARK-7989] [SPARK-10651] [CORE] [TESTS] Increase timeout to fix flaky tests I noticed only one block manager registered with master in an unsuccessful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/3534/) ``` 15/09/16 13:02:30.981 pool-1-thread-1-ScalaTest-running-BroadcastSuite INFO SparkContext: Running Spark version 1.6.0-SNAPSHOT ... 15/09/16 13:02:38.133 sparkDriver-akka.actor.default-dispatcher-19 INFO BlockManagerMasterEndpoint: Registering block manager localhost:48196 with 530.3 MB RAM, BlockManagerId(0, localhost, 48196) ``` In addition, the first block manager needed 7+ seconds to start. But the test expected 2 block managers so it failed. However, there was no exception in this log file. So I checked a successful build (https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3536/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.2,label=spark-test/) and it needed 4-5 seconds to set up the local cluster: ``` 15/09/16 18:11:27.738 sparkWorker1-akka.actor.default-dispatcher-5 INFO Worker: Running Spark version 1.6.0-SNAPSHOT ... 
15/09/16 18:11:30.838 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:54202 with 530.3 MB RAM, BlockManagerId(1, localhost, 54202) 15/09/16 18:11:32.112 sparkDriver-akka.actor.default-dispatcher-20 INFO BlockManagerMasterEndpoint: Registering block manager localhost:32955 with 530.3 MB RAM, BlockManagerId(0, localhost, 32955) ``` In this build, the first block manager needed only 3+ seconds to start. Comparing these two builds, I guess it's possible that the local cluster in `BroadcastSuite` cannot be ready in 10 seconds if the Jenkins worker is busy. So I just increased the timeout to 60 seconds to see if this can fix the issue. Author: zsxwing Closes #8813 from zsxwing/fix-BroadcastSuite. --- .../scala/org/apache/spark/ExternalShuffleServiceSuite.scala | 2 +- .../test/scala/org/apache/spark/broadcast/BroadcastSuite.scala | 2 +- .../apache/spark/scheduler/SparkListenerWithClusterSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index e846a72c888c6..231f4631e0a47 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -61,7 +61,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. // In this case, we won't receive FetchFailed. And it will make this test fail. // Therefore, we should wait until all slaves are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index fb7a8ae3f9d41..ba21075ce6be5 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -311,7 +311,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up try { - _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) + _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) _sc } catch { case e: Throwable => diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d1e23ed527ff1..9fa8859382911 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -43,7 +43,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext // This test will check if the number of executors received by "SparkListener" is same as the // number of all executors, so we need to wait until all executors are up - sc.jobProgressListener.waitUntilExecutorsUp(2, 10000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 60000) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) From ca9fe540fe04e2e230d1e76526b5502bab152914 Mon Sep 17 00:00:00 2001 From: Jacek 
Laskowski Date: Mon, 21 Sep 2015 19:46:39 +0100 Subject: [PATCH 360/802] [SPARK-10662] [DOCS] Code snippets are not properly formatted in tables * Backticks are processed properly in Spark Properties table * Removed unnecessary spaces * See http://people.apache.org/~pwendell/spark-nightly/spark-master-docs/latest/running-on-yarn.html Author: Jacek Laskowski Closes #8795 from jaceklaskowski/docs-yarn-formatting. --- docs/configuration.md | 97 +++++++++++++++-------------- docs/programming-guide.md | 100 +++++++++++++++--------------- docs/running-on-mesos.md | 14 ++--- docs/running-on-yarn.md | 106 ++++++++++++++++---------------- docs/sql-programming-guide.md | 16 ++--- docs/submitting-applications.md | 8 +-- 6 files changed, 171 insertions(+), 170 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 5ec097c78aa38..b22587c70316b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -34,20 +34,20 @@ val conf = new SparkConf() val sc = new SparkContext(conf) {% endhighlight %} -Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may +Note that we can have more than 1 thread in local mode, and in cases like Spark Streaming, we may actually require one to prevent any sort of starvation issues. -Properties that specify some time duration should be configured with a unit of time. +Properties that specify some time duration should be configured with a unit of time. The following format is accepted: - + 25ms (milliseconds) 5s (seconds) 10m or 10min (minutes) 3h (hours) 5d (days) 1y (years) - - + + Properties that specify a byte size should be configured with a unit of size. The following format is accepted: @@ -140,7 +140,7 @@ of the most common options to set are: Amount of memory to use for the driver process, i.e. where SparkContext is initialized. (e.g. 1g, 2g). - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-memory command line option @@ -207,7 +207,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-class-path command line option or in + Instead, please set this through the --driver-class-path command line option or in your default properties file. @@ -216,10 +216,10 @@ Apart from these, the following properties are also available, and may be useful (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-java-options command line option or in + Instead, please set this through the --driver-java-options command line option or in your default properties file. @@ -228,10 +228,10 @@ Apart from these, the following properties are also available, and may be useful (none) Set a special library path to use when launching the driver JVM. - +
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-library-path command line option or in + Instead, please set this through the --driver-library-path command line option or in your default properties file. @@ -242,7 +242,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in the the driver. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. - + This is used in cluster mode only. @@ -250,8 +250,8 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to prepend to the classpath of executors. This exists primarily for - backwards-compatibility with older versions of Spark. Users typically should not need to set + Extra classpath entries to prepend to the classpath of executors. This exists primarily for + backwards-compatibility with older versions of Spark. Users typically should not need to set this option. @@ -259,9 +259,9 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraJavaOptions (none) - A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the + A string of extra JVM options to pass to executors. For instance, GC settings or other logging. + Note that it is illegal to set Spark properties or heap size settings with this option. Spark + properties should be set using a SparkConf object or the spark-defaults.conf file used with the spark-submit script. Heap size settings can be set with spark.executor.memory. @@ -305,7 +305,7 @@ Apart from these, the following properties are also available, and may be useful daily Set the time interval by which the executor logs will be rolled over. - Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or + Rolling is disabled by default. Valid values are daily, hourly, minutely or any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -330,13 +330,13 @@ Apart from these, the following properties are also available, and may be useful spark.python.profile false - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + Enable profiling in Python worker, the profile result will show up by sc.show_profiles(), or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed manually, + sc.dump_profiles(path). If some of the profile results had been displayed manually, they will not be displayed automatically before driver exiting. - By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by - passing a profiler class in as a parameter to the `SparkContext` constructor. + By default the pyspark.profiler.BasicProfiler will be used, but this can be overridden by + passing a profiler class in as a parameter to the SparkContext constructor. 
@@ -460,11 +460,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.service.enabled false - Enables the external shuffle service. This service preserves the shuffle files written by - executors so the executors can be safely removed. This must be enabled if + Enables the external shuffle service. This service preserves the shuffle files written by + executors so the executors can be safely removed. This must be enabled if spark.dynamicAllocation.enabled is "true". The external shuffle service must be set up in order to enable it. See - dynamic allocation + dynamic allocation configuration and setup documentation for more information. @@ -747,9 +747,9 @@ Apart from these, the following properties are also available, and may be useful 1 in YARN mode, all the available cores on the worker in standalone mode. The number of cores to use on each executor. For YARN and standalone mode only. - - In standalone mode, setting this parameter allows an application to run multiple executors on - the same worker, provided that there are enough cores on that worker. Otherwise, only one + + In standalone mode, setting this parameter allows an application to run multiple executors on + the same worker, provided that there are enough cores on that worker. Otherwise, only one executor per application will run on each worker. @@ -893,14 +893,14 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000s - This is set to a larger value to disable the transport failure detector that comes built in to - Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger - interval value reduces network overhead and a smaller value ( ~ 1 s) might be more - informative for Akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` - if you need to. A likely positive use case for using failure detector would be: a sensistive - failure detector can help evict rogue executors quickly. However this is usually not the case - as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling - this leads to a lot of exchanges of heart beats between nodes leading to flooding the network + This is set to a larger value to disable the transport failure detector that comes built in to + Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger + interval value reduces network overhead and a smaller value ( ~ 1 s) might be more + informative for Akka's failure detector. Tune this in combination of spark.akka.heartbeat.pauses + if you need to. A likely positive use case for using failure detector would be: a sensistive + failure detector can help evict rogue executors quickly. However this is usually not the case + as GC pauses and network lags are expected in a real Spark cluster. Apart from that enabling + this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -909,9 +909,9 @@ Apart from these, the following properties are also available, and may be useful 6000s This is set to a larger value to disable the transport failure detector that comes built in to Akka. - It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart + It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause for Akka. This can be used to control sensitivity to GC pauses. 
Tune - this along with `spark.akka.heartbeat.interval` if you need to. + this along with spark.akka.heartbeat.interval if you need to. @@ -978,7 +978,7 @@ Apart from these, the following properties are also available, and may be useful spark.network.timeout 120s - Default timeout for all network interactions. This config will be used in place of + Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, spark.storage.blockManagerSlaveTimeoutMs, spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or @@ -991,8 +991,8 @@ Apart from these, the following properties are also available, and may be useful Maximum number of retries when binding to a port before giving up. When a port is given a specific value (non 0), each subsequent retry will - increment the port used in the previous attempt by 1 before retrying. This - essentially allows it to try a range of ports from the start port specified + increment the port used in the previous attempt by 1 before retrying. This + essentially allows it to try a range of ports from the start port specified to port + maxRetries. @@ -1191,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.executorIdleTimeout 60s - If dynamic allocation is enabled and an executor has been idle for more than this duration, + If dynamic allocation is enabled and an executor has been idle for more than this duration, the executor will be removed. For more detail, see this description. @@ -1424,11 +1424,11 @@ Apart from these, the following properties are also available, and may be useful false Enables or disables Spark Streaming's internal backpressure mechanism (since 1.5). - This enables the Spark Streaming to control the receiving rate based on the + This enables the Spark Streaming to control the receiving rate based on the current batch scheduling delays and processing times so that the system receives - only as fast as the system can process. Internally, this dynamically sets the + only as fast as the system can process. Internally, this dynamically sets the maximum receiving rate of receivers. This rate is upper bounded by the values - `spark.streaming.receiver.maxRate` and `spark.streaming.kafka.maxRatePerPartition` + spark.streaming.receiver.maxRate and spark.streaming.kafka.maxRatePerPartition if they are set (see below). @@ -1542,15 +1542,15 @@ The following variables can be set in `spark-env.sh`: Environment VariableMeaning JAVA_HOME - Location where Java is installed (if it's not on your default `PATH`). + Location where Java is installed (if it's not on your default PATH). PYSPARK_PYTHON - Python binary executable to use for PySpark in both driver and workers (default is `python`). + Python binary executable to use for PySpark in both driver and workers (default is python). PYSPARK_DRIVER_PYTHON - Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). + Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). SPARK_LOCAL_IP @@ -1580,4 +1580,3 @@ Spark uses [log4j](http://logging.apache.org/log4j/) for logging. You can config To specify a different configuration directory other than the default "SPARK_HOME/conf", you can set SPARK_CONF_DIR. Spark will use the the configuration files (spark-defaults.conf, spark-env.sh, log4j.properties, etc) from this directory. 
- diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 4cf83bb392636..8ad238315f12c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -182,8 +182,8 @@ in-process. In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `--master` argument, and you can add JARs to the classpath -by passing a comma-separated list to the `--jars` argument. You can also add dependencies -(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates +by passing a comma-separated list to the `--jars` argument. You can also add dependencies +(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly four cores, use: @@ -217,7 +217,7 @@ context connects to using the `--master` argument, and you can add Python .zip, to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies (e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. SonaType) -can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in +can be passed to the `--repositories` argument. Any python dependencies a Spark Package has (listed in the requirements.txt of that package) must be manually installed using pip when necessary. For example, to run `bin/pyspark` on exactly four cores, use: @@ -249,8 +249,8 @@ the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support $ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook" ./bin/pyspark {% endhighlight %} -After the IPython Notebook server is launched, you can create a new "Python 2" notebook from -the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of +After the IPython Notebook server is launched, you can create a new "Python 2" notebook from +the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of your notebook before you start to try Spark from the IPython notebook.
@@ -418,9 +418,9 @@ Apart from text files, Spark's Python API also supports several other data forma **Writable Support** -PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the -resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, -PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following +PySpark SequenceFile support loads an RDD of key-value pairs within Java, converts Writables to base Java types, and pickles the +resulting Java objects using [Pyrolite](https://github.com/irmen/Pyrolite/). When saving an RDD of key-value pairs to SequenceFile, +PySpark does the reverse. It unpickles Python objects into Java objects and then converts them to Writables. The following Writables are automatically converted: @@ -435,9 +435,9 @@ Writables are automatically converted:
MapWritabledict
-Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, -users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default -converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get +Arrays are not handled out-of-the-box. Users need to specify custom `ArrayWritable` subtypes when reading or writing. When writing, +users also need to specify custom converters that convert arrays to custom `ArrayWritable` subtypes. When reading, the default +converter will convert custom `ArrayWritable` subtypes to Java `Object[]`, which then get pickled to Python tuples. To get Python `array.array` for arrays of primitive types, users need to specify custom converters. **Saving and Loading SequenceFiles** @@ -454,7 +454,7 @@ classes can be specified, but for standard Writables this is not required. **Saving and Loading Other Hadoop Input/Output Formats** -PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. +PySpark can also read any Hadoop InputFormat or write any Hadoop OutputFormat, for both 'new' and 'old' Hadoop MapReduce APIs. If required, a Hadoop configuration can be passed in as a Python dict. Here is an example using the Elasticsearch ESInputFormat: @@ -474,15 +474,15 @@ Note that, if the InputFormat simply depends on a Hadoop configuration and/or in the key and value classes can easily be converted according to the above table, then this approach should work well for such cases. -If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to +If you have custom serialized binary data (such as loading data from Cassandra / HBase), then you will first need to transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler. -A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided -for this. Simply extend this trait and implement your transformation code in the ```convert``` -method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark +A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided +for this. Simply extend this trait and implement your transformation code in the ```convert``` +method. Remember to ensure that this class, along with any dependencies required to access your ```InputFormat```, are packaged into your Spark job jar and included on the PySpark classpath. -See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and -the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) +See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and +the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/pythonconverters) for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
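As an illustrative aside to the Writable round-trip described in the guide text above, a minimal PySpark sketch; it assumes an already-running SparkContext named `sc` and a writable scratch path, both placeholders rather than part of the patch:

```
# Minimal sketch of the SequenceFile round-trip discussed above.
# Assumes an existing SparkContext `sc` and a writable scratch directory.
pairs = sc.parallelize([("a", 1), ("b", 2), ("c", 3)])

# On save, the Python objects are pickled and converted to Writables
# (Text / IntWritable here) on the JVM side.
pairs.saveAsSequenceFile("/tmp/seqfile-demo")

# On load, the Writables are converted back and unpickled into str/int pairs.
loaded = sc.sequenceFile("/tmp/seqfile-demo")
print(sorted(loaded.collect()))   # [('a', 1), ('b', 2), ('c', 3)]
```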
@@ -758,7 +758,7 @@ One of the harder things about Spark is understanding the scope and life cycle o #### Example -Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN): +Consider the naive RDD element sum below, which behaves completely differently depending on whether execution is happening within the same JVM. A common example of this is when running Spark in `local` mode (`--master = local[n]`) versus deploying a Spark application to a cluster (e.g. via spark-submit to YARN):
@@ -777,7 +777,7 @@ println("Counter value: " + counter)
{% highlight java %} int counter = 0; -JavaRDD rdd = sc.parallelize(data); +JavaRDD rdd = sc.parallelize(data); // Wrong: Don't do this!! rdd.foreach(x -> counter += x); @@ -803,7 +803,7 @@ print("Counter value: " + counter) #### Local vs. cluster modes -The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. +The primary challenge is that the behavior of the above code is undefined. In local mode with a single JVM, the above code will sum the values within the RDD and store it in **counter**. This is because both the RDD and the variable **counter** are in the same memory space on the driver node. However, in `cluster` mode, what happens is more complicated, and the above may not work as intended. To execute jobs, Spark breaks up the processing of RDD operations into tasks - each of which is operated on by an executor. Prior to execution, Spark computes the **closure**. The closure is those variables and methods which must be visible for the executor to perform its computations on the RDD (in this case `foreach()`). This closure is serialized and sent to each executor. In `local` mode, there is only the one executors so everything shares the same closure. In other modes however, this is not the case and the executors running on seperate worker nodes each have their own copy of the closure. @@ -813,9 +813,9 @@ To ensure well-defined behavior in these sorts of scenarios one should use an [` In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. -#### Printing elements of an RDD +#### Printing elements of an RDD Another common idiom is attempting to print out the elements of an RDD using `rdd.foreach(println)` or `rdd.map(println)`. On a single machine, this will generate the expected output and print all the RDD's elements. However, in `cluster` mode, the output to `stdout` being called by the executors is now writing to the executor's `stdout` instead, not the one on the driver, so `stdout` on the driver won't show these! To print all elements on the driver, one can use the `collect()` method to first bring the RDD to the driver node thus: `rdd.collect().foreach(println)`. This can cause the driver to run out of memory, though, because `collect()` fetches the entire RDD to a single machine; if you only need to print a few elements of the RDD, a safer approach is to use the `take()`: `rdd.take(100).foreach(println)`. - + ### Working with Key-Value Pairs
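A minimal sketch of the closure discussion above, contrasting the unreliable driver-side counter with an accumulator (assuming an existing SparkContext named `sc`; the values are placeholders):

```
# Contrasts the broken closure-counter pattern with an accumulator.
# Assumes an existing SparkContext `sc`.
rdd = sc.parallelize([1, 2, 3, 4, 5])

counter = 0
def add_to_counter(x):
    global counter
    counter += x          # updates a copy local to the worker process

rdd.foreach(add_to_counter)
print(counter)            # typically still 0 on the driver

total = sc.accumulator(0)
rdd.foreach(lambda x: total.add(x))
print(total.value)        # 15: worker updates are merged back to the driver
```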
@@ -859,7 +859,7 @@ only available on RDDs of key-value pairs. The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. -In Java, key-value pairs are represented using the +In Java, key-value pairs are represented using the [scala.Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) class from the Scala standard library. You can simply call `new Tuple2(a, b)` to create a tuple, and access its fields later with `tuple._1()` and `tuple._2()`. @@ -974,7 +974,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or aggregateByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. @@ -1025,7 +1025,7 @@ for details. repartitionAndSortWithinPartitions(partitioner) Repartition the RDD according to the given partitioner and, within each resulting partition, - sort records by their keys. This is more efficient than calling repartition and then sorting within + sort records by their keys. This is more efficient than calling repartition and then sorting within each partition because it can push the sorting down into the shuffle machinery. @@ -1038,7 +1038,7 @@ RDD API doc [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), [Python](api/python/pyspark.html#pyspark.RDD), [R](api/R/index.html)) - + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1094,7 +1094,7 @@ for details. foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1118,13 +1118,13 @@ co-located to compute the result. In Spark, data is generally not distributed across partitions to be in the necessary place for a specific operation. During computations, a single task will operate on a single partition - thus, to organize all the data for a single `reduceByKey` reduce task to execute, Spark needs to perform an -all-to-all operation. It must read from all partitions to find all the values for all keys, -and then bring together values across partitions to compute the final result for each key - +all-to-all operation. It must read from all partitions to find all the values for all keys, +and then bring together values across partitions to compute the final result for each key - this is called the **shuffle**. Although the set of elements in each partition of newly shuffled data will be deterministic, and so -is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably -ordered data following shuffle then it's possible to use: +is the ordering of partitions themselves, the ordering of these elements is not. If one desires predictably +ordered data following shuffle then it's possible to use: * `mapPartitions` to sort each partition using, for example, `.sorted` * `repartitionAndSortWithinPartitions` to efficiently sort partitions while simultaneously repartitioning @@ -1141,26 +1141,26 @@ network I/O. To organize data for the shuffle, Spark generates sets of tasks - * organize the data, and a set of *reduce* tasks to aggregate it. This nomenclature comes from MapReduce and does not directly relate to Spark's `map` and `reduce` operations. -Internally, results from individual map tasks are kept in memory until they can't fit. Then, these -are sorted based on the target partition and written to a single file. On the reduce side, tasks +Internally, results from individual map tasks are kept in memory until they can't fit. Then, these +are sorted based on the target partition and written to a single file. On the reduce side, tasks read the relevant sorted blocks. - -Certain shuffle operations can consume significant amounts of heap memory since they employ -in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations -generate these on the reduce side. When data does not fit in memory Spark will spill these tables + +Certain shuffle operations can consume significant amounts of heap memory since they employ +in-memory data structures to organize records before or after transferring them. Specifically, +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations +generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files -are preserved until the corresponding RDDs are no longer used and are garbage collected. -This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references -to these RDDs or if GC does not kick in frequently. 
This means that long-running Spark jobs may +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the -'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). +'Shuffle Behavior' section within the [Spark Configuration Guide](configuration.html). ## RDD Persistence @@ -1246,7 +1246,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1345,7 +1345,7 @@ Accumulators are variables that are only "added" to through an associative opera therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric types, and programmers can add support for new types. If accumulators are created with a name, they will be -displayed in Spark's UI. This can be useful for understanding the progress of +displayed in Spark's UI. This can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python). An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks @@ -1474,8 +1474,8 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
-For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator -will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware of that each task's update may be applied more than once if tasks or job stages are re-executed. Accumulators do not change the lazy evaluation model of Spark. If they are being updated within an operation on an RDD, their value is only updated once that RDD is computed as part of an action. Consequently, accumulator updates are not guaranteed to be executed when made within a lazy transformation like `map()`. The below code fragment demonstrates this property: @@ -1486,7 +1486,7 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being {% highlight scala %} val accum = sc.accumulator(0) data.map { x => accum += x; f(x) } -// Here, accum is still 0 because no actions have caused the `map` to be computed. +// Here, accum is still 0 because no actions have caused the map to be computed. {% endhighlight %}
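To make the lazy-evaluation point above concrete, here is a minimal sketch (assuming the same `SparkContext` `sc` and an integer RDD `data` as in the snippet) showing that the update only lands once an action runs:

{% highlight scala %}
val accum = sc.accumulator(0)
val mapped = data.map { x => accum += x; x }
// Still 0 here: map is lazy, so no task has run and no update has been applied.
println(accum.value)
// An action forces the map to be computed and the tasks' updates to be merged in.
mapped.count()
println(accum.value) // now reflects the sum of the elements
{% endhighlight %}

Because the update happens inside a transformation rather than an action, it may be applied more than once if tasks or stages are re-executed, as noted above.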
@@ -1553,7 +1553,7 @@ Several changes were made to the Java API: code that `extends Function` should `implement Function` instead. * New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning +* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning `(Key, List)` pairs to `(Key, Iterable)`.
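Looping back to the shuffle-ordering discussion earlier in this guide, a minimal sketch of the two options mentioned there (assuming a pair RDD named `pairs` with orderable keys; the partition count of 8 is arbitrary):

{% highlight scala %}
import org.apache.spark.HashPartitioner

// Push the sort into the shuffle itself: each output partition comes back sorted by key.
val sortedInShuffle = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(8))

// Or sort each already-shuffled partition after the fact.
val sortedPerPartition = pairs.mapPartitions(
  iter => iter.toList.sortBy(_._1).iterator,
  preservesPartitioning = true)
{% endhighlight %}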
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 330c159c67bca..460a66f37dd64 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -245,7 +245,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.coarse false - If set to "true", runs over Mesos clusters in + If set to true, runs over Mesos clusters in "coarse-grained" sharing mode, where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use @@ -254,16 +254,16 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.extra.cores - 0 + 0 Set the extra amount of cpus to request per task. This setting is only used for Mesos coarse grain mode. The total amount of cores requested per task is the number of cores in the offer plus the extra cores configured. - Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. spark.mesos.mesosExecutor.cores - 1.0 + 1.0 (Fine-grained mode only) Number of cores to give each Mesos executor. This does not include the cores used to run the Spark tasks. In other words, even if no Spark task @@ -287,7 +287,7 @@ See the [configuration page](configuration.html) for information on Spark config Set the list of volumes which will be mounted into the Docker image, which was set using spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings following the form passed to docker run -v. That is they take the form: + mappings following the form passed to docker run -v. That is they take the form:
[host_path:]container_path[:ro|:rw]
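For illustration, a hedged sketch of setting these Mesos properties programmatically (the image name and paths are placeholders, not recommendations):

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.coarse", "true")
  // Image the executors run in; the value here is purely illustrative.
  .set("spark.mesos.executor.docker.image", "example/spark-executor:1.5.0")
  // Comma-separated mappings in the docker run -v form shown above.
  .set("spark.mesos.executor.docker.volumes",
    "/var/data:/var/data:ro,/var/log/spark:/var/log/spark:rw")
{% endhighlight %}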
@@ -318,7 +318,7 @@ See the [configuration page](configuration.html) for information on Spark config executor memory * 0.10, with minimum of 384 The amount of additional memory, specified in MB, to be allocated per executor. By default, - the overhead will be larger of either 384 or 10% of `spark.executor.memory`. If it's set, + the overhead will be larger of either 384 or 10% of spark.executor.memory. If set, the final overhead will be this value. @@ -339,7 +339,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.mesos.secret - (none)/td> + (none) Set the secret with which Spark framework will use to authenticate with Mesos. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3a961d245f3de..0e25ccf512c02 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -23,7 +23,7 @@ Unlike [Spark standalone](spark-standalone.html) and [Mesos](running-on-mesos.ht To launch a Spark application in `yarn-cluster` mode: $ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - + For example: $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ @@ -43,7 +43,7 @@ To launch a Spark application in `yarn-client` mode, do the same, but replace `y ## Adding Other JARs -In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. $ ./bin/spark-submit --class my.main.Class \ --master yarn-cluster \ @@ -64,16 +64,16 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Debugging your Application -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the `yarn logs` command. yarn logs -applicationId - + will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. 
The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and doesn't require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +large value (e.g. `36000`), and then access the application cache through `yarn.nodemanager.local-dirs` on the nodes on which containers are launched. This directory contains the launch script, JARs, and all environment variables used for launching each container. This process is useful for debugging classpath problems in particular. (Note that enabling this requires admin privileges on cluster @@ -92,7 +92,7 @@ Note that for the first option, both executors and the application master will s log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. #### Spark Properties @@ -100,24 +100,26 @@ If you need a reference to the proper location to put log files in the YARN so t Property NameDefaultMeaning spark.yarn.am.memory - 512m + 512m Amount of memory to use for the YARN Application Master in client mode, in the same format as JVM memory strings (e.g. 512m, 2g). In cluster mode, use spark.driver.memory instead. +

+ Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. spark.driver.cores - 1 + 1 Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN AM. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN AM instead. + Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. + In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. spark.yarn.am.cores - 1 + 1 Number of cores to use for the YARN Application Master in client mode. In cluster mode, use spark.driver.cores instead. @@ -125,39 +127,39 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.am.waitTime - 100s + 100s - In `yarn-cluster` mode, time for the application master to wait for the - SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait + In yarn-cluster mode, time for the YARN Application Master to wait for the + SparkContext to be initialized. In yarn-client mode, time for the YARN Application Master to wait for the driver to connect to it. spark.yarn.submit.file.replication - The default HDFS replication (usually 3) + The default HDFS replication (usually 3) HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. spark.yarn.preserve.staging.files - false + false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. spark.yarn.scheduler.heartbeat.interval-ms - 3000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. - The value is capped at half the value of YARN's configuration for the expiry interval - (yarn.am.liveness-monitor.expiry-interval-ms). + The value is capped at half the value of YARN's configuration for the expiry interval, i.e. + yarn.am.liveness-monitor.expiry-interval-ms. spark.yarn.scheduler.initial-allocation.interval - 200ms + 200ms The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager when there are pending container allocation requests. It should be no larger than @@ -177,8 +179,8 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.historyServer.address (none) - The address of the Spark history server (i.e. host.com:18080). The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. - For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For eg, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to `${hadoopconf-yarn.resourcemanager.hostname}:18080`. + The address of the Spark history server, e.g. host.com:18080. 
The address should not contain a scheme (http://). Defaults to not being set since the history server is an optional service. This address is given to the YARN ResourceManager when the Spark application finishes to link the application from the ResourceManager UI to the Spark history server UI. + For this property, YARN properties can be used as variables, and these are substituted by Spark at runtime. For example, if the Spark history server runs on the same node as the YARN ResourceManager, it can be set to ${hadoopconf-yarn.resourcemanager.hostname}:18080. @@ -197,42 +199,42 @@ If you need a reference to the proper location to put log files in the YARN so t spark.executor.instances - 2 + 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per executor. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the executor size (typically 6-10%). spark.yarn.driver.memoryOverhead driverMemory * 0.10, with minimum of 384 - The amount of off heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). + The amount of off-heap memory (in megabytes) to be allocated per driver in cluster mode. This is memory that accounts for things like VM overheads, interned strings, other native overheads, etc. This tends to grow with the container size (typically 6-10%). spark.yarn.am.memoryOverhead AM memory * 0.10, with minimum of 384 - Same as spark.yarn.driver.memoryOverhead, but for the Application Master in client mode. + Same as spark.yarn.driver.memoryOverhead, but for the YARN Application Master in client mode. spark.yarn.am.port (random) - Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. + Port for the YARN Application Master to listen on. In YARN client mode, this is used to communicate between the Spark driver running on a gateway and the YARN Application Master running on YARN. In YARN cluster mode, this is used for the dynamic executor feature, where it handles the kill from the scheduler backend. spark.yarn.queue - default + default The name of the YARN queue to which the application is submitted. 
@@ -245,18 +247,18 @@ If you need a reference to the proper location to put log files in the YARN so t By default, Spark on YARN will use a Spark jar installed locally, but the Spark jar can also be in a world-readable location on HDFS. This allows YARN to cache it on nodes so that it doesn't need to be distributed each time an application runs. To point to a jar on HDFS, for example, - set this configuration to "hdfs:///some/path". + set this configuration to hdfs:///some/path. spark.yarn.access.namenodes (none) - A list of secure HDFS namenodes your Spark application is going to access. For - example, `spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032`. - The Spark application must have acess to the namenodes listed and Kerberos must - be properly configured to be able to access them (either in the same realm or in - a trusted realm). Spark acquires security tokens for each of the namenodes so that + A comma-separated list of secure HDFS namenodes your Spark application is going to access. For + example, spark.yarn.access.namenodes=hdfs://nn1.com:8032,hdfs://nn2.com:8032. + The Spark application must have access to the namenodes listed and Kerberos must + be properly configured to be able to access them (either in the same realm or in + a trusted realm). Spark acquires security tokens for each of the namenodes so that the Spark application can access those remote HDFS clusters. @@ -264,18 +266,18 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.appMasterEnv.[EnvironmentVariableName] (none) - Add the environment variable specified by EnvironmentVariableName to the - Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In `yarn-cluster` mode this controls - the environment of the SPARK driver and in `yarn-client` mode it only controls - the environment of the executor launcher. + Add the environment variable specified by EnvironmentVariableName to the + Application Master process launched on YARN. The user can specify multiple of + these and to set multiple environment variables. In yarn-cluster mode this controls + the environment of the Spark driver and in yarn-client mode it only controls + the environment of the executor launcher. spark.yarn.containerLauncherMaxThreads - 25 + 25 - The maximum number of threads to use in the application master for launching executor containers. + The maximum number of threads to use in the YARN Application Master for launching executor containers. @@ -283,19 +285,19 @@ If you need a reference to the proper location to put log files in the YARN so t (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use `spark.driver.extraJavaOptions` instead. + In cluster mode, use spark.driver.extraJavaOptions instead. spark.yarn.am.extraLibraryPath (none) - Set a special library path to use when launching the application master in client mode. + Set a special library path to use when launching the YARN Application Master in client mode. spark.yarn.maxAppAttempts - yarn.resourcemanager.am.max-attempts in YARN + yarn.resourcemanager.am.max-attempts in YARN The maximum number of attempts that will be made to submit the application. It should be no larger than the global number of max attempts in the YARN configuration. 
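To show how a few of the properties in this table fit together, a minimal sketch using `SparkConf` (every value below is a placeholder chosen for illustration, not a recommendation):

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Client-mode Application Master resources; in cluster mode use
  // spark.driver.memory and spark.driver.cores instead.
  .set("spark.yarn.am.memory", "1g")
  .set("spark.yarn.am.cores", "2")
  // Fixed executor count; incompatible with spark.dynamicAllocation.enabled.
  .set("spark.executor.instances", "4")
  // Comma-separated secure namenodes the application will access.
  .set("spark.yarn.access.namenodes", "hdfs://nn1.com:8032,hdfs://nn2.com:8032")
  // One environment variable per spark.yarn.appMasterEnv.* property
  // (the variable name here is hypothetical).
  .set("spark.yarn.appMasterEnv.SPARK_LOG_LEVEL", "INFO")
  .set("spark.yarn.queue", "default")
{% endhighlight %}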
@@ -303,10 +305,10 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.submit.waitAppCompletion - true + true In YARN cluster mode, controls whether the client waits to exit until the application completes. - If set to true, the client process will stay alive reporting the application's status. + If set to true, the client process will stay alive reporting the application's status. Otherwise, the client process will exit after submission. @@ -332,7 +334,7 @@ If you need a reference to the proper location to put log files in the YARN so t (none) The full path to the file that contains the keytab for the principal specified above. - This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + This keytab will be copied to the node running the YARN Application Master via the Secure Distributed Cache, for renewing the login tickets and the delegation tokens periodically. @@ -371,14 +373,14 @@ If you need a reference to the proper location to put log files in the YARN so t spark.yarn.security.tokens.${service}.enabled - true + true Controls whether to retrieve delegation tokens for non-HDFS services when security is enabled. By default, delegation tokens for all supported services are retrieved when those services are configured, but it's possible to disable that behavior if it somehow conflicts with the application being run.

- Currently supported services are: hive, hbase + Currently supported services are: hive, hbase @@ -387,5 +389,5 @@ If you need a reference to the proper location to put log files in the YARN so t - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. - In `yarn-cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `yarn-client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `yarn-client` mode, only the Spark executors do. -- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. +- The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `yarn-cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7ae9244c271e3..a1cbc7de97c65 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1676,7 +1676,7 @@ results <- collect(sql(sqlContext, "FROM src SELECT key, value")) ### Interacting with Different Versions of Hive Metastore One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, -which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. Note that independent of the version of Hive that is being used to talk to the metastore, internally Spark SQL will compile against Hive 1.2.1 and use those classes for internal execution (serdes, UDFs, UDAFs, etc). @@ -1706,8 +1706,8 @@ The following options can be used to configure the version of Hive that is used either 1.2.1 or not defined.

  • maven
  • Use Hive jars of specified version downloaded from Maven repositories. This configuration - is not generally recommended for production deployments. -
  • A classpath in the standard format for the JVM. This classpath must include all of Hive + is not generally recommended for production deployments. +
  • A classpath in the standard format for the JVM. This classpath must include all of Hive and its dependencies, including the correct version of Hadoop. These jars only need to be present on the driver, but if you are running in yarn cluster mode then you must ensure they are packaged with you application.
  • @@ -1806,7 +1806,7 @@ the Data Sources API. The following options are supported:
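As a hedged illustration of choosing between these options, a sketch using `SparkConf`; the keys are the standard `spark.sql.hive.metastore.version` and `spark.sql.hive.metastore.jars` properties this table documents, and the version value is only an example:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Version of the Hive metastore to talk to (illustrative value);
  // typically set before the HiveContext is created.
  .set("spark.sql.hive.metastore.version", "0.13.1")
  // "builtin" (default), "maven", or a JVM classpath as described above.
  .set("spark.sql.hive.metastore.jars", "maven")
{% endhighlight %}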
    {% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( +val jdbcDF = sqlContext.read.format("jdbc").options( Map("url" -> "jdbc:postgresql:dbserver", "dbtable" -> "schema.tablename")).load() {% endhighlight %} @@ -2023,11 +2023,11 @@ options. - Optimized execution using manually managed memory (Tungsten) is now enabled by default, along with code generation for expression evaluation. These features can both be disabled by setting - `spark.sql.tungsten.enabled` to `false. - - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting + `spark.sql.tungsten.enabled` to `false`. + - Parquet schema merging is no longer enabled by default. It can be re-enabled by setting `spark.sql.parquet.mergeSchema` to `true`. - - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or - access nested values. For example `df['table.column.nestedField']`. However, this means that if + - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or + access nested values. For example `df['table.column.nestedField']`. However, this means that if your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 7ea4d6f1a3f8f..915be0f479157 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -103,7 +103,7 @@ run it with `--help`. Here are a few examples of common options: export HADOOP_CONF_DIR=XXX ./bin/spark-submit \ --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ # can also be `yarn-client` for client mode + --master yarn-cluster \ # can also be yarn-client for client mode --executor-memory 20G \ --num-executors 50 \ /path/to/examples.jar \ @@ -174,9 +174,9 @@ This can use up a significant amount of space over time and will need to be clea is handled automatically, and with Spark standalone, automatic cleanup can be configured with the `spark.worker.cleanup.appDataTtl` property. -Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates -with `--packages`. All transitive dependencies will be handled when using this command. Additional -repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. +Users may also include any other dependencies by supplying a comma-delimited list of maven coordinates +with `--packages`. All transitive dependencies will be handled when using this command. Additional +repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`. These commands can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages. For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries From 331f0b10f78a37d96d3e573d211d74a0935265db Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Mon, 21 Sep 2015 12:09:00 -0700 Subject: [PATCH 361/802] [SPARK-9642] [ML] LinearRegression should supported weighted data In many modeling application, data points are not necessarily sampled with equal probabilities. Linear regression should support weighting which account the over or under sampling. work in progress. 
Author: Meihua Wu Closes #8631 from rotationsymmetry/SPARK-9642. --- .../ml/regression/LinearRegression.scala | 164 +++++++++++------- .../ml/regression/LinearRegressionSuite.scala | 88 ++++++++++ project/MimaExcludes.scala | 8 +- 3 files changed, 191 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index e4602d36ccc87..78a67c5fdab20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.functions.{col, udf, lit} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol - with HasFitIntercept with HasStandardization + with HasFitIntercept with HasStandardization with HasWeightCol + +/** + * Class that represents an instance of weighted data point with label and features. + * + * TODO: Refactor this class to proper place. + * + * @param label Label for this data point. + * @param weight The weight of this instance. + * @param features The vector of features for this data point. + */ +private[regression] case class Instance(label: Double, weight: Double, features: Vector) /** * :: Experimental :: @@ -123,30 +131,43 @@ class LinearRegression(override val uid: String) def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) + /** + * Whether to over-/under-sample training instances according to the given weights in weightCol. + * If empty, all instances are treated equally (weight 1.0). + * Default is empty, so all instances have weight one. + * @group setParam + */ + def setWeightCol(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "") + override protected def train(dataset: DataFrame): LinearRegressionModel = { // Extract columns from data. If dataset is persisted, do not persist instances. 
- val instances = extractLabeledPoints(dataset).map { - case LabeledPoint(label: Double, features: Vector) => (label, features) + val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val (summarizer, statCounter) = instances.treeAggregate( - (new MultivariateOnlineSummarizer, new StatCounter))( - seqOp = (c, v) => (c, v) match { - case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), - (label: Double, features: Vector)) => - (summarizer.add(features), statCounter.merge(label)) - }, - combOp = (c1, c2) => (c1, c2) match { - case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), - (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => - (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) - }) - - val numFeatures = summarizer.mean.size - val yMean = statCounter.mean - val yStd = math.sqrt(statCounter.variance) + val (featuresSummarizer, ySummarizer) = { + val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + instance: Instance) => + (c._1.add(instance.features, instance.weight), + c._2.add(Vectors.dense(instance.label), instance.weight)) + + val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer), + c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) => + (c1._1.merge(c2._1), c1._2.merge(c2._2)) + + instances.treeAggregate( + new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val numFeatures = featuresSummarizer.mean.size + val yMean = ySummarizer.mean(0) + val yStd = math.sqrt(ySummarizer.variance(0)) // If the yStd is zero, then the intercept is yMean with zero weights; // as a result, training is not needed. @@ -167,8 +188,8 @@ class LinearRegression(override val uid: String) return copyValues(model.setSummary(trainingSummary)) } - val featuresMean = summarizer.mean.toArray - val featuresStd = summarizer.variance.toArray.map(math.sqrt) + val featuresMean = featuresSummarizer.mean.toArray + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. @@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] ( /** * :: Experimental :: - * Linear regression training results. + * Linear regression training results. Currently, the training summary ignores the + * training weights except for the objective trace. * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] ( * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * - * @param weights The weights/coefficients corresponding to the features. + * @param coefficients The coefficients corresponding to the features. * @param labelStd The standard deviation value of the label. * @param labelMean The mean value of the label. * @param fitIntercept Whether to fit an intercept term. 
@@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] ( * @param featuresMean The mean values of the features. */ private class LeastSquaresAggregator( - weights: Vector, + coefficients: Vector, labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -493,26 +515,28 @@ private class LeastSquaresAggregator( featuresMean: Array[Double]) extends Serializable { private var totalCnt: Long = 0L + private var weightSum: Double = 0.0 private var lossSum = 0.0 - private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { - val weightsArray = weights.toArray.clone() + private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = { + val coefficientsArray = coefficients.toArray.clone() var sum = 0.0 var i = 0 - val len = weightsArray.length + val len = coefficientsArray.length while (i < len) { if (featuresStd(i) != 0.0) { - weightsArray(i) /= featuresStd(i) - sum += weightsArray(i) * featuresMean(i) + coefficientsArray(i) /= featuresStd(i) + sum += coefficientsArray(i) * featuresMean(i) } else { - weightsArray(i) = 0.0 + coefficientsArray(i) = 0.0 } i += 1 } - (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length) + val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0 + (coefficientsArray, offset, coefficientsArray.length) } - private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray) private val gradientSumArray = Array.ofDim[Double](dim) @@ -520,30 +544,33 @@ private class LeastSquaresAggregator( * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient * of the objective function. * - * @param label The label for this data point. - * @param data The features for one data point in dense/sparse vector format to be added - * into this aggregator. + * @param instance The data point instance to be added. * @return This LeastSquaresAggregator object. */ - def add(label: Double, data: Vector): this.type = { - require(dim == data.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${data.size}.") + def add(instance: Instance): this.type = + instance match { case Instance(label, weight, features) => + require(dim == features.size, s"Dimensions mismatch when adding new sample." 
+ + s" Expecting $dim but got ${features.size}.") + require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") - val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + if (weight == 0.0) return this - if (diff != 0) { - val localGradientSumArray = gradientSumArray - data.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - localGradientSumArray(index) += diff * value / featuresStd(index) + val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + features.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += weight * diff * value / featuresStd(index) + } } + lossSum += weight * diff * diff / 2.0 } - lossSum += diff * diff / 2.0 - } - totalCnt += 1 - this - } + totalCnt += 1 + weightSum += weight + this + } /** * Merge another LeastSquaresAggregator, and update the loss and gradient @@ -557,8 +584,9 @@ private class LeastSquaresAggregator( require(dim == other.dim, s"Dimensions mismatch when merging with another " + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") - if (other.totalCnt != 0) { + if (other.weightSum != 0) { totalCnt += other.totalCnt + weightSum += other.weightSum lossSum += other.lossSum var i = 0 @@ -574,11 +602,17 @@ private class LeastSquaresAggregator( def count: Long = totalCnt - def loss: Double = lossSum / totalCnt + def loss: Double = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") + lossSum / weightSum + } def gradient: Vector = { + require(weightSum > 0.0, s"The effective number of instances should be " + + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) - scal(1.0 / totalCnt, result) + scal(1.0 / weightSum, result) result } } @@ -589,7 +623,7 @@ private class LeastSquaresAggregator( * It's used in Breeze's convex optimization routines. 
*/ private class LeastSquaresCostFun( - data: RDD[(Double, Vector)], + data: RDD[Instance], labelStd: Double, labelMean: Double, fitIntercept: Boolean, @@ -598,17 +632,13 @@ private class LeastSquaresCostFun( featuresMean: Array[Double], effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { - override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { - val w = Vectors.fromBreeze(weights) + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { + val coeff = Vectors.fromBreeze(coefficients) - val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd, labelMean, fitIntercept, featuresStd, featuresMean))( - seqOp = (c, v) => (c, v) match { - case (aggregator, (label, features)) => aggregator.add(label, features) - }, - combOp = (c1, c2) => (c1, c2) match { - case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) - }) + seqOp = (aggregator, instance) => aggregator.add(instance), + combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) val totalGradientArray = leastSquaresAggregator.gradient.toArray @@ -616,7 +646,7 @@ private class LeastSquaresCostFun( 0.0 } else { var sum = 0.0 - w.foreachActive { (index, value) => + coeff.foreachActive { (index, value) => // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. sum += { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2aaee71ecc734..8428f4f00b370 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.ml.regression +import scala.util.Random + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .zip(testSummary.residuals.select("residuals").collect()) .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 } } + + test("linear regression with weighted samples"){ + val (data, weightedData) = { + val activeData = LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + + val rnd = new Random(8392) + val signedData = activeData.map { case p: LabeledPoint => + (rnd.nextGaussian() > 0.0, p) + } + + val data1 = signedData.flatMap { + case (true, p) => Iterator(p, p) + case (false, p) => Iterator(p) + } + + val weightedSignedData = signedData.flatMap { + case (true, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 1.2, features), + Instance(label, weight = 0.8, features) + ) + case (false, LabeledPoint(label, features)) => + Iterator( + Instance(label, weight = 0.3, features), + Instance(label, weight = 0.1, features), + Instance(label, weight = 0.6, features) + ) + } + + val noiseData = 
LinearDataGenerator.generateLinearInput( + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + val weightedNoiseData = noiseData.map { + case LabeledPoint(label, features) => Instance(label, weight = 0, features) + } + val data2 = weightedSignedData ++ weightedNoiseData + + (sqlContext.createDataFrame(sc.parallelize(data1, 4)), + sqlContext.createDataFrame(sc.parallelize(data2, 4))) + } + + val trainer1a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model1a0 = trainer1a.fit(data) + val model1a1 = trainer1a.fit(weightedData) + val model1b = trainer1b.fit(weightedData) + assert(model1a0.weights !~= model1a1.weights absTol 1E-3) + assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) + assert(model1a0.weights ~== model1b.weights absTol 1E-3) + assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) + + val trainer2a = (new LinearRegression).setFitIntercept(true) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model2a0 = trainer2a.fit(data) + val model2a1 = trainer2a.fit(weightedData) + val model2b = trainer2b.fit(weightedData) + assert(model2a0.weights !~= model2a1.weights absTol 1E-3) + assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) + assert(model2a0.weights ~== model2b.weights absTol 1E-3) + assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) + + val trainer3a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true) + val model3a0 = trainer3a.fit(data) + val model3a1 = trainer3a.fit(weightedData) + val model3b = trainer3b.fit(weightedData) + assert(model3a0.weights !~= model3a1.weights absTol 1E-3) + assert(model3a0.weights ~== model3b.weights absTol 1E-3) + + val trainer4a = (new LinearRegression).setFitIntercept(false) + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") + .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false) + val model4a0 = trainer4a.fit(data) + val model4a1 = trainer4a.fit(weightedData) + val model4b = trainer4b.fit(weightedData) + assert(model4a0.weights !~= model4a1.weights absTol 1E-3) + assert(model4a0.weights ~== model4b.weights absTol 1E-3) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 814a11e588ceb..b2e6be706637b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -70,10 +70,14 @@ object MimaExcludes { "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ - Seq( + ) ++ Seq( ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") + ) ++ Seq( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), + ProblemFilters.exclude[MissingMethodProblem]( + 
"org.apache.spark.ml.regression.LeastSquaresCostFun.this") ) case v if v.startsWith("1.5") => Seq( From b78c65b03ae87a3ba348c9d29ff4c296349eb49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?hushan=5B=E8=83=A1=E7=8F=8A=5D?= Date: Mon, 21 Sep 2015 14:26:15 -0500 Subject: [PATCH 362/802] [SPARK-5259] [CORE] don't submit stage until its dependencies map outputs are registered MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track pending tasks by partition ID instead of Task objects. Before this change, failure & retry could result in a case where a stage got submitted before the map output from its dependencies get registered. This was due to an error in the condition for registering map outputs. Author: hushan[胡珊] Author: Imran Rashid Closes #7699 from squito/SPARK-5259. --- .../apache/spark/scheduler/DAGScheduler.scala | 12 +- .../org/apache/spark/scheduler/Stage.scala | 2 +- .../spark/scheduler/TaskSetManager.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 197 ++++++++++++++++-- 4 files changed, 191 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3c9a66e504403..394228b2728d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -944,7 +944,7 @@ class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingTasks.clear() + stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = { @@ -1060,8 +1060,8 @@ class DAGScheduler( if (tasks.size > 0) { logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingTasks ++= tasks - logDebug("New pending tasks: " + stage.pendingTasks) + stage.pendingPartitions ++= tasks.map(_.partitionId) + logDebug("New pending partitions: " + stage.pendingPartitions) taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1152,7 +1152,7 @@ class DAGScheduler( case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) - stage.pendingTasks -= task + stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1198,7 +1198,7 @@ class DAGScheduler( shuffleStage.addOutputLoc(smt.partitionId, status) } - if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { + if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") logInfo("running: " + runningStages) @@ -1242,7 +1242,7 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingTasks += task + stage.pendingPartitions += task.partitionId case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b37eccbd0f7b8..a3829c319c48d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -66,7 +66,7 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - var pendingTasks = new HashSet[Task[_]] + val pendingPartitions = new HashSet[Int] /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 62af9031b9f8b..c02597c4365c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -487,8 +487,8 @@ private[spark] class TaskSetManager( // a good proxy to task serialization time. // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" - logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( - taskName, taskId, host, taskLocality, serializedTask.limit)) + logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + + s"$taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1c55f90ad9b44..6b5bcf0574de6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -479,8 +479,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -490,7 +490,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -782,8 +782,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === HashSet("hostA", "hostB")) @@ -1035,6 +1035,173 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + /** + * This test runs a three stage job, with a fetch failure in stage 1. 
but during the retry, we + * have completions from both the first & second attempt of stage 1. So all the map output is + * available before we finish any task set for stage 1. We want to make sure that we don't + * submit stage 2 until the map output for stage 1 is registered + */ + test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val firstShuffleId = firstShuffleDep.shuffleId + val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep)) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + // things start out smoothly, stage 0 completes with no issues + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostB", shuffleMapRdd.partitions.length)), + (Success, makeMapStatus("hostA", shuffleMapRdd.partitions.length)) + )) + + // then one executor dies, and a task fails in stage 1 + runEvent(ExecutorLost("exec-hostA")) + runEvent(CompletionEvent( + taskSets(1).tasks(0), + FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + null, + null, + createFakeTaskInfo(), + null)) + + // so we resubmit stage 0, which completes happily + scheduler.resubmitFailedStages() + val stage0Resubmit = taskSets(2) + assert(stage0Resubmit.stageId == 0) + assert(stage0Resubmit.stageAttemptId === 1) + val task = stage0Resubmit.tasks(0) + assert(task.partitionId === 2) + runEvent(CompletionEvent( + task, + Success, + makeMapStatus("hostC", shuffleMapRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // now here is where things get tricky : we will now have a task set representing + // the second attempt for stage 1, but we *also* have some tasks for the first attempt for + // stage 1 still going + val stage1Resubmit = taskSets(3) + assert(stage1Resubmit.stageId == 1) + assert(stage1Resubmit.stageAttemptId === 1) + assert(stage1Resubmit.tasks.length === 3) + + // we'll have some tasks finish from the first attempt, and some finish from the second attempt, + // so that we actually have all stage outputs, though no attempt has completed all its + // tasks + runEvent(CompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + runEvent(CompletionEvent( + taskSets(3).tasks(1), + Success, + makeMapStatus("hostC", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + // late task finish from the first attempt + runEvent(CompletionEvent( + taskSets(1).tasks(2), + Success, + makeMapStatus("hostB", reduceRdd.partitions.length), + null, + createFakeTaskInfo(), + null)) + + // What should happen now is that we submit stage 2. However, we might not see an error + // b/c of DAGScheduler's error handling (it tends to swallow errors and just log them). But + // we can check some conditions. + // Note that the really important thing here is not so much that we submit stage 2 *immediately* + // but that we don't end up with some error from these interleaved completions. 
It would also + // be OK (though sub-optimal) if stage 2 simply waited until the resubmission of stage 1 had + // all its tasks complete + + // check that we have all the map output for stage 0 (it should have been there even before + // the last round of completions from stage 1, but just to double check it hasn't been messed + // up) and also the newly available stage 1 + val stageToReduceIdxs = Seq( + 0 -> (0 until 3), + 1 -> (0 until 1) + ) + for { + (stage, reduceIdxs) <- stageToReduceIdxs + reduceIdx <- reduceIdxs + } { + // this would throw an exception if the map status hadn't been registered + val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 2 has been submitted + assert(taskSets.size == 5) + val stage2TaskSet = taskSets(4) + assert(stage2TaskSet.stageId == 2) + assert(stage2TaskSet.stageAttemptId == 0) + } + + /** + * We lose an executor after completing some shuffle map tasks on it. Those tasks get + * resubmitted, and when they finish the job completes normally + */ + test("register map outputs correctly after ExecutorLost and task Resubmitted") { + val firstRDD = new MyRDD(sc, 3, Nil) + val firstShuffleDep = new ShuffleDependency(firstRDD, null) + val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep)) + submit(reduceRdd, Array(0)) + + // complete some of the tasks from the first stage, on one host + runEvent(CompletionEvent( + taskSets(0).tasks(0), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Success, + makeMapStatus("hostA", reduceRdd.partitions.length), null, createFakeTaskInfo(), null)) + + // now that host goes down + runEvent(ExecutorLost("exec-hostA")) + + // so we resubmit those tasks + runEvent(CompletionEvent( + taskSets(0).tasks(0), Resubmitted, null, null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent( + taskSets(0).tasks(1), Resubmitted, null, null, createFakeTaskInfo(), null)) + + // now complete everything on a different host + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)) + )) + + // now we should submit stage 1, and the map output from stage 0 should be registered + + // check that we have all the map output for stage 0 + (0 until reduceRdd.partitions.length).foreach { reduceIdx => + val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + // really we should have already thrown an exception rather than fail either of these + // asserts, but just to be extra defensive let's double check the statuses are OK + assert(statuses != null) + assert(statuses.nonEmpty) + } + + // and check that stage 1 has been submitted + assert(taskSets.size == 2) + val stage1TaskSet = taskSets(1) + assert(stage1TaskSet.stageId == 1) + assert(stage1TaskSet.stageAttemptId == 0) + } + /** * Makes sure that failures of stage used by multiple jobs are correctly handled. 
* @@ -1393,8 +1560,8 @@ class DAGSchedulerSuite // Submit a map stage by itself submitMapStage(shuffleDep) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), - (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) assert(results.size === 1) results.clear() assertDataStructuresEmpty() @@ -1407,7 +1574,7 @@ class DAGSchedulerSuite // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) results.clear() @@ -1452,8 +1619,8 @@ class DAGSchedulerSuite // Complete the first stage assert(taskSets(0).stageId === 0) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", rdd1.partitions.size)), - (Success, makeMapStatus("hostB", rdd1.partitions.size)))) + (Success, makeMapStatus("hostA", rdd1.partitions.length)), + (Success, makeMapStatus("hostB", rdd1.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) assert(listener1.results.size === 1) @@ -1461,7 +1628,7 @@ class DAGSchedulerSuite // When attempting the second stage, show a fetch failure assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( - (Success, makeMapStatus("hostA", rdd2.partitions.size)), + (Success, makeMapStatus("hostA", rdd2.partitions.length)), (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -1469,7 +1636,7 @@ class DAGSchedulerSuite // Stage 0 should now be running as task set 2; make its task succeed assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( - (Success, makeMapStatus("hostC", rdd2.partitions.size)))) + (Success, makeMapStatus("hostC", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) assert(listener2.results.size === 0) // Second stage listener should still not have a result @@ -1477,8 +1644,8 @@ class DAGSchedulerSuite // Stage 1 should now be running as task set 3; make its first task succeed assert(taskSets(3).stageId === 1) complete(taskSets(3), Seq( - (Success, makeMapStatus("hostB", rdd2.partitions.size)), - (Success, makeMapStatus("hostD", rdd2.partitions.size)))) + (Success, makeMapStatus("hostB", rdd2.partitions.length)), + (Success, makeMapStatus("hostD", rdd2.partitions.length)))) assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) assert(listener2.results.size === 1) @@ -1494,7 +1661,7 @@ class DAGSchedulerSuite // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 assert(taskSets(5).stageId === 1) complete(taskSets(5), Seq( - (Success, makeMapStatus("hostE", rdd2.partitions.size)))) + (Success, makeMapStatus("hostE", rdd2.partitions.length)))) complete(taskSets(6), Seq( (Success, 53))) 
assert(listener3.results === Map(0 -> 52, 1 -> 53)) From ba882db6f43dd2bc05675133158e4664ed07030a Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:06:23 -0700 Subject: [PATCH 363/802] [SPARK-9769] [ML] [PY] add python api for countvectorizermodel From JIRA: Add Python API, user guide and example for ml.feature.CountVectorizerModel Author: Holden Karau Closes #8561 from holdenk/SPARK-9769-add-python-api-for-countvectorizermodel. --- python/pyspark/ml/feature.py | 148 +++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 92db8df80280b..f41d72f877256 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -26,12 +26,13 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector -__all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'IndexToString', 'MinMaxScaler', 'MinMaxScalerModel', 'NGram', 'Normalizer', - 'OneHotEncoder', 'PCA', 'PCAModel', 'PolynomialExpansion', 'RegexTokenizer', - 'RFormula', 'RFormulaModel', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', - 'StopWordsRemover', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', - 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', 'Word2Vec', 'Word2VecModel'] +__all__ = ['Binarizer', 'Bucketizer', 'CountVectorizer', 'CountVectorizerModel', 'DCT', + 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', 'IndexToString', 'MinMaxScaler', + 'MinMaxScalerModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PCA', 'PCAModel', + 'PolynomialExpansion', 'RegexTokenizer', 'RFormula', 'RFormulaModel', 'SQLTransformer', + 'StandardScaler', 'StandardScalerModel', 'StopWordsRemover', 'StringIndexer', + 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'VectorSlicer', + 'Word2Vec', 'Word2VecModel'] @inherit_doc @@ -171,6 +172,141 @@ def getSplits(self): return self.getOrDefault(self.splits) +@inherit_doc +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. + >>> df = sqlContext.createDataFrame( + ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], + ... ["label", "raw"]) + >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") + >>> model = cv.fit(df) + >>> model.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... + >>> sorted(map(str, model.vocabulary)) + ['a', 'b', 'c'] + """ + + # a placeholder to make it appear in the generated doc + minTF = Param( + Params._dummy(), "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then this " + + "specifies a fraction (out of the document's token count). Note that the parameter is " + + "only used in transform of CountVectorizerModel and does not affect fitting. 
Default 1.0") + minDF = Param( + Params._dummy(), "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0") + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.") + + @keyword_only + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + """ + super(CountVectorizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", + self.uid) + self.minTF = Param( + self, "minTF", "Filter to ignore rare words in" + + " a document. For each document, terms with frequency/count less than the given" + + " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + + " times the term must appear in the document); if this is a double in [0,1), then " + + "this specifies a fraction (out of the document's token count). Note that the " + + "parameter is only used in transform of CountVectorizerModel and does not affect" + + "fitting. Default 1.0") + self.minDF = Param( + self, "minDF", "Specifies the minimum number of" + + " different documents a term must appear in to be included in the vocabulary." + + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of " + + "documents. Default 1.0") + self.vocabSize = Param( + self, "vocabSize", "max size of the vocabulary. Default 1 << 18.") + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + """ + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + Set the params for the CountVectorizer + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. + """ + self._paramMap[self.minTF] = value + return self + + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + def setMinDF(self, value): + """ + Sets the value of :py:attr:`minDF`. + """ + self._paramMap[self.minDF] = value + return self + + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + def setVocabSize(self, value): + """ + Sets the value of :py:attr:`vocabSize`. + """ + self._paramMap[self.vocabSize] = value + return self + + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + def _create_model(self, java_model): + return CountVectorizerModel(java_model) + + +class CountVectorizerModel(JavaModel): + """ + .. note:: Experimental + + Model fitted by CountVectorizer. + """ + + @property + def vocabulary(self): + """ + An array of terms in the vocabulary. 
+ """ + return self._call_java("vocabulary") + + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol): """ From aeef44a3e32b53f7adecc8e9cfd684fb4598e87d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 21 Sep 2015 13:11:28 -0700 Subject: [PATCH 364/802] [SPARK-3147] [MLLIB] [STREAMING] Streaming 2-sample statistical significance testing Implementation of significance testing using Streaming API. Author: Feynman Liang Author: Feynman Liang Closes #4716 from feynmanliang/ab_testing. --- .../examples/mllib/StreamingTestExample.scala | 90 +++++++ .../spark/mllib/stat/test/StreamingTest.scala | 145 +++++++++++ .../mllib/stat/test/StreamingTestMethod.scala | 167 ++++++++++++ .../spark/mllib/stat/test/TestResult.scala | 22 ++ .../spark/mllib/stat/StreamingTestSuite.scala | 243 ++++++++++++++++++ 5 files changed, 667 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala new file mode 100644 index 0000000000000..ab29f90254d34 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.SparkConf +import org.apache.spark.mllib.stat.test.StreamingTest +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.util.Utils + +/** + * Perform streaming testing using Welch's 2-sample t-test on a stream of data, where the data + * stream arrives as text files in a directory. Stops when the two groups are statistically + * significant (p-value < 0.05) or after a user-specified timeout in number of batches is exceeded. + * + * The rows of the text files must be in the form `Boolean, Double`. 
For example: + * false, -3.92 + * true, 99.32 + * + * Usage: + * StreamingTestExample + * + * To run on your local machine using the directory `dataDir` with 5 seconds between each batch and + * a timeout after 100 insignificant batches, call: + * $ bin/run-example mllib.StreamingTestExample dataDir 5 100 + * + * As you add text files to `dataDir` the significance test wil continually update every + * `batchDuration` seconds until the test becomes significant (p-value < 0.05) or the number of + * batches processed exceeds `numBatchesTimeout`. + */ +object StreamingTestExample { + + def main(args: Array[String]) { + if (args.length != 3) { + // scalastyle:off println + System.err.println( + "Usage: StreamingTestExample " + + " ") + // scalastyle:on println + System.exit(1) + } + val dataDir = args(0) + val batchDuration = Seconds(args(1).toLong) + val numBatchesTimeout = args(2).toInt + + val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") + val ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint({ + val dir = Utils.createTempDir() + dir.toString + }) + + val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { + case Array(label, value) => (label.toBoolean, value.toDouble) + }) + + val streamingTest = new StreamingTest() + .setPeacePeriod(0) + .setWindowSize(0) + .setTestMethod("welch") + + val out = streamingTest.registerStream(data) + out.print() + + // Stop processing if test becomes significant or we time out + var timeoutCounter = numBatchesTimeout + out.foreachRDD { rdd => + timeoutCounter -= 1 + val anySignificant = rdd.map(_.pValue < 0.05).fold(false)(_ || _) + if (timeoutCounter == 0 || anySignificant) rdd.context.stop() + } + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala new file mode 100644 index 0000000000000..75c6a51d09571 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * :: Experimental :: + * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The + * Boolean identifies which sample each observation comes from, and the Double is the numeric value + * of the observation. 
+ * + * To address novelty affects, the `peacePeriod` specifies a set number of initial + * [[org.apache.spark.rdd.RDD]] batches of the [[DStream]] to be dropped from significance testing. + * + * The `windowSize` sets the number of batches each significance test is to be performed over. The + * window is sliding with a stride length of 1 batch. Setting windowSize to 0 will perform + * cumulative processing, using all batches seen so far. + * + * Different tests may be used for assessing statistical significance depending on assumptions + * satisfied by data. For more details, see [[StreamingTestMethod]]. The `testMethod` specifies + * which test will be used. + * + * Use a builder pattern to construct a streaming test in an application, for example: + * {{{ + * val model = new StreamingTest() + * .setPeacePeriod(10) + * .setWindowSize(0) + * .setTestMethod("welch") + * .registerStream(DStream) + * }}} + */ +@Experimental +@Since("1.6.0") +class StreamingTest @Since("1.6.0") () extends Logging with Serializable { + private var peacePeriod: Int = 0 + private var windowSize: Int = 0 + private var testMethod: StreamingTestMethod = WelchTTest + + /** Set the number of initial batches to ignore. Default: 0. */ + @Since("1.6.0") + def setPeacePeriod(peacePeriod: Int): this.type = { + this.peacePeriod = peacePeriod + this + } + + /** + * Set the number of batches to compute significance tests over. Default: 0. + * A value of 0 will use all batches seen so far. + */ + @Since("1.6.0") + def setWindowSize(windowSize: Int): this.type = { + this.windowSize = windowSize + this + } + + /** Set the statistical method used for significance testing. Default: "welch" */ + @Since("1.6.0") + def setTestMethod(method: String): this.type = { + this.testMethod = StreamingTestMethod.getTestMethodFromName(method) + this + } + + /** + * Register a [[DStream]] of values for significance testing. + * + * @param data stream of (key,value) pairs where the key denotes group membership (true = + * experiment, false = control) and the value is the numerical metric to test for + * significance + * @return stream of significance testing results + */ + @Since("1.6.0") + def registerStream(data: DStream[(Boolean, Double)]): DStream[StreamingTestResult] = { + val dataAfterPeacePeriod = dropPeacePeriod(data) + val summarizedData = summarizeByKeyAndWindow(dataAfterPeacePeriod) + val pairedSummaries = pairSummaries(summarizedData) + + testMethod.doTest(pairedSummaries) + } + + /** Drop all batches inside the peace period. */ + private[stat] def dropPeacePeriod( + data: DStream[(Boolean, Double)]): DStream[(Boolean, Double)] = { + data.transform { (rdd, time) => + if (time.milliseconds > data.slideDuration.milliseconds * peacePeriod) { + rdd + } else { + data.context.sparkContext.parallelize(Seq()) + } + } + } + + /** Compute summary statistics over each key and the specified test window size. 
*/ + private[stat] def summarizeByKeyAndWindow( + data: DStream[(Boolean, Double)]): DStream[(Boolean, StatCounter)] = { + if (this.windowSize == 0) { + data.updateStateByKey[StatCounter]( + (newValues: Seq[Double], oldSummary: Option[StatCounter]) => { + val newSummary = oldSummary.getOrElse(new StatCounter()) + newSummary.merge(newValues) + Some(newSummary) + }) + } else { + val windowDuration = data.slideDuration * this.windowSize + data + .groupByKeyAndWindow(windowDuration) + .mapValues { values => + val summary = new StatCounter() + values.foreach(value => summary.merge(value)) + summary + } + } + } + + /** + * Transform a stream of summaries into pairs representing summary statistics for control group + * and experiment group up to this batch. + */ + private[stat] def pairSummaries(summarizedData: DStream[(Boolean, StatCounter)]) + : DStream[(StatCounter, StatCounter)] = { + summarizedData + .map[(Int, StatCounter)](x => (0, x._2)) + .groupByKey() // should be length two (control/experiment group) + .map(x => (x._2.head, x._2.last)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala new file mode 100644 index 0000000000000..a7eaed51b4d55 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTestMethod.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import java.io.Serializable + +import scala.language.implicitConversions +import scala.math.pow + +import com.twitter.chill.MeatLocker +import org.apache.commons.math3.stat.descriptive.StatisticalSummaryValues +import org.apache.commons.math3.stat.inference.TTest + +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter + +/** + * Significance testing methods for [[StreamingTest]]. New 2-sample statistical significance tests + * should extend [[StreamingTestMethod]] and introduce a new entry in + * [[StreamingTestMethod.TEST_NAME_TO_OBJECT]] + */ +private[stat] sealed trait StreamingTestMethod extends Serializable { + + val methodName: String + val nullHypothesis: String + + protected type SummaryPairStream = + DStream[(StatCounter, StatCounter)] + + /** + * Perform streaming 2-sample statistical significance testing. + * + * @param sampleSummaries stream pairs of summary statistics for the 2 samples + * @return stream of rest results + */ + def doTest(sampleSummaries: SummaryPairStream): DStream[StreamingTestResult] + + /** + * Implicit adapter to convert between streaming summary statistics type and the type required by + * the t-testing libraries. 
+ */ + protected implicit def toApacheCommonsStats( + summaryStats: StatCounter): StatisticalSummaryValues = { + new StatisticalSummaryValues( + summaryStats.mean, + summaryStats.variance, + summaryStats.count, + summaryStats.max, + summaryStats.min, + summaryStats.mean * summaryStats.count + ) + } +} + +/** + * Performs Welch's 2-sample t-test. The null hypothesis is that the two data sets have equal mean. + * This test does not assume equal variance between the two samples and does not assume equal + * sample size. + * + * @see http://en.wikipedia.org/wiki/Welch%27s_t_test + */ +private[stat] object WelchTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Welch's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def welchDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = { + val s1 = sample1.getVariance + val n1 = sample1.getN + val s2 = sample2.getVariance + val n2 = sample2.getN + + val a = pow(s1, 2) / n1 + val b = pow(s2, 2) / n2 + + pow(a + b, 2) / ((pow(a, 2) / (n1 - 1)) + (pow(b, 2) / (n2 - 1))) + } + + new StreamingTestResult( + tTester.get.tTest(statsA, statsB), + welchDF(statsA, statsB), + tTester.get.t(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Performs Students's 2-sample t-test. The null hypothesis is that the two data sets have equal + * mean. This test assumes equal variance between the two samples and does not assume equal sample + * size. For unequal variances, Welch's t-test should be used instead. + * + * @see http://en.wikipedia.org/wiki/Student%27s_t-test + */ +private[stat] object StudentTTest extends StreamingTestMethod with Logging { + + override final val methodName = "Student's 2-sample t-test" + override final val nullHypothesis = "Both groups have same mean" + + private final val tTester = MeatLocker(new TTest()) + + override def doTest(data: SummaryPairStream): DStream[StreamingTestResult] = + data.map[StreamingTestResult]((test _).tupled) + + private def test( + statsA: StatCounter, + statsB: StatCounter): StreamingTestResult = { + def studentDF(sample1: StatisticalSummaryValues, sample2: StatisticalSummaryValues): Double = + sample1.getN + sample2.getN - 2 + + new StreamingTestResult( + tTester.get.homoscedasticTTest(statsA, statsB), + studentDF(statsA, statsB), + tTester.get.homoscedasticT(statsA, statsB), + methodName, + nullHypothesis + ) + } +} + +/** + * Companion object holding supported [[StreamingTestMethod]] names and handles conversion between + * strings used in [[StreamingTest]] configuration and actual method implementation. + * + * Currently supported tests: `welch`, `student`. + */ +private[stat] object StreamingTestMethod { + // Note: after new `StreamingTestMethod`s are implemented, please update this map. + private final val TEST_NAME_TO_OBJECT: Map[String, StreamingTestMethod] = Map( + "welch"->WelchTTest, + "student"->StudentTTest) + + def getTestMethodFromName(method: String): StreamingTestMethod = + TEST_NAME_TO_OBJECT.get(method) match { + case Some(test) => test + case None => + throw new IllegalArgumentException( + "Unrecognized method name. 
Supported streaming test methods: " + + TEST_NAME_TO_OBJECT.keys.mkString(", ")) + } +} + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index d01b3707be944..b0916d3e84651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -115,3 +115,25 @@ class KolmogorovSmirnovTestResult private[stat] ( "Kolmogorov-Smirnov test summary:\n" + super.toString } } + +/** + * :: Experimental :: + * Object containing the test results for streaming testing. + */ +@Experimental +@Since("1.6.0") +private[stat] class StreamingTestResult @Since("1.6.0") ( + @Since("1.6.0") override val pValue: Double, + @Since("1.6.0") override val degreesOfFreedom: Double, + @Since("1.6.0") override val statistic: Double, + @Since("1.6.0") val method: String, + @Since("1.6.0") override val nullHypothesis: String) + extends TestResult[Double] with Serializable { + + override def toString: String = { + "Streaming test summary:\n" + + s"method: $method\n" + + super.toString + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala new file mode 100644 index 0000000000000..d3e9ef4ff079c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, WelchTTest} +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.StatCounter +import org.apache.spark.util.random.XORShiftRandom + +class StreamingTestSuite extends SparkFunSuite with TestSuiteBase { + + override def maxWaitTimeMillis : Int = 30000 + + test("accuracy for null hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for alternative hypothesis using welch t-test") { + // set parameters + val testMethod = "welch" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == WelchTTest.methodName)) + } + + test("accuracy for null hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + + assert(outputBatches.flatten.forall(res => + res.pValue > 0.05 && res.method == StudentTTest.methodName)) + } + + test("accuracy for alternative hypothesis using student t-test") { + // set parameters + val testMethod = "student" + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + .setTestMethod(testMethod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + 
assert(outputBatches.flatten.forall(res => + res.pValue < 0.05 && res.method == StudentTTest.methodName)) + } + + test("batches within same test window are grouped") { + // set parameters + val testWindow = 3 + val numBatches = 5 + val pointsPerBatch = 100 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(testWindow) + .setPeacePeriod(0) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, + (inputDStream: DStream[(Boolean, Double)]) => model.summarizeByKeyAndWindow(inputDStream)) + val outputBatches = runStreams[(Boolean, StatCounter)](ssc, numBatches, numBatches) + val outputCounts = outputBatches.flatten.map(_._2.count) + + // number of batches seen so far does not exceed testWindow, expect counts to continue growing + for (i <- 0 until testWindow) { + assert(outputCounts.drop(2 * i).take(2).forall(_ == (i + 1) * pointsPerBatch / 2)) + } + + // number of batches seen exceeds testWindow, expect counts to be constant + assert(outputCounts.drop(2 * (testWindow - 1)).forall(_ == testWindow * pointsPerBatch / 2)) + } + + + test("entries in peace period are dropped") { + // set parameters + val peacePeriod = 3 + val numBatches = 7 + val pointsPerBatch = 1000 + val meanA = -10 + val stdevA = 1 + val meanB = 10 + val stdevB = 1 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(peacePeriod) + + val input = generateTestData( + numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.dropPeacePeriod(inputDStream)) + val outputBatches = runStreams[(Boolean, Double)](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.length == (numBatches - peacePeriod) * pointsPerBatch) + } + + test("null hypothesis when only data from one group is present") { + // set parameters + val numBatches = 2 + val pointsPerBatch = 1000 + val meanA = 0 + val stdevA = 0.001 + val meanB = 0 + val stdevB = 0.001 + + val model = new StreamingTest() + .setWindowSize(0) + .setPeacePeriod(0) + + val input = generateTestData(numBatches, pointsPerBatch, meanA, stdevA, meanB, stdevB, 42) + .map(batch => batch.filter(_._1)) // only keep one test group + + // setup and run the model + val ssc = setupStreams( + input, (inputDStream: DStream[(Boolean, Double)]) => model.registerStream(inputDStream)) + val outputBatches = runStreams[StreamingTestResult](ssc, numBatches, numBatches) + + assert(outputBatches.flatten.forall(result => (result.pValue - 1.0).abs < 0.001)) + } + + // Generate testing input with half of the entries in group A and half in group B + private def generateTestData( + numBatches: Int, + pointsPerBatch: Int, + meanA: Double, + stdevA: Double, + meanB: Double, + stdevB: Double, + seed: Int): (IndexedSeq[IndexedSeq[(Boolean, Double)]]) = { + val rand = new XORShiftRandom(seed) + val numTrues = pointsPerBatch / 2 + val data = (0 until numBatches).map { i => + (0 until numTrues).map { idx => (true, meanA + stdevA * rand.nextGaussian())} ++ + (pointsPerBatch / 2 until pointsPerBatch).map { idx => + (false, meanB + stdevB * rand.nextGaussian()) + } + } + + data + } +} From 97a99dde6e8d69a4c4c135dc1d9b1520b2548b5b Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 21 Sep 2015 13:15:44 -0700 Subject: [PATCH 365/802] [SPARK-10676] [DOCS] Add documentation for SASL encryption options. 
Author: Marcelo Vanzin Closes #8803 from vanzin/SPARK-10676. --- docs/configuration.md | 16 ++++++++++++++++ docs/security.md | 22 ++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index b22587c70316b..284f97ad09ec3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1285,6 +1285,22 @@ Apart from these, the following properties are also available, and may be useful not running on YARN and authentication is enabled. + + spark.authenticate.enableSaslEncryption + false + + Enable encrypted communication when authentication is enabled. This option is currently + only supported by the block transfer service. + + + + spark.network.sasl.serverAlwaysEncrypt + false + + Disable unencrypted connections for services that support SASL authentication. This is + currently supported by the external shuffle service. + + spark.core.connection.ack.wait.timeout 60s diff --git a/docs/security.md b/docs/security.md index d4ffa60e59a33..177109415180b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,9 +23,16 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. However SSL is not supported yet for WebUI and block transfer service. +Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +supported for the block transfer service. Encryption is not yet supported for the WebUI. -Connection encryption (SSL) configuration is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle +files, cached data, and other application files. If encrypting this data is desired, a workaround is +to configure your cluster manager to store application data on encrypted disks. + +### SSL Configuration + +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. 
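A minimal sketch of how an application might combine the two properties documented in `configuration.md` above (the rest of the setup, including shared-secret distribution, is assumed to already be in place; see the SASL section that follows):

    import org.apache.spark.SparkConf

    // Illustrative only; the application name is made up. Authentication is a
    // prerequisite for SASL encryption of the block transfer service, per the docs above.
    val conf = new SparkConf()
      .setAppName("sasl-encrypted-app")
      .set("spark.authenticate", "true")
      .set("spark.authenticate.enableSaslEncryption", "true")

    // On the external shuffle service side, operators can additionally set
    // spark.network.sasl.serverAlwaysEncrypt=true to reject unencrypted connections.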
@@ -47,6 +54,17 @@ follows: * Import all exported public keys into a single trust-store * Distribute the trust-store over the nodes +### Configuring SASL Encryption + +SASL encryption is currently supported for the block transfer service when authentication +(`spark.authenticate`) is enabled. To enable SASL encryption for an application, set +`spark.authenticate.enableSaslEncryption` to `true` in the application's configuration. + +When using an external shuffle service, it's possible to disable unencrypted connections by setting +`spark.network.sasl.serverAlwaysEncrypt` to `true` in the shuffle service's configuration. If that +option is enabled, applications that are not set up to use SASL encryption will fail to connect to +the shuffle service. + ## Configuring Ports for Network Security Spark makes heavy use of the network, and some environments have strict requirements for using tight From 362539f8d97f6bb67f0d0983f7dea36b77cc9d18 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 13:33:10 -0700 Subject: [PATCH 366/802] [SPARK-10630] [SQL] Add a createDataFrame API that takes in a java list It would be nice to support creating a DataFrame directly from a Java List of Row. Author: Holden Karau Closes #8779 from holdenk/SPARK-10630-create-DataFrame-from-Java-List. --- .../scala/org/apache/spark/sql/SQLContext.scala | 14 ++++++++++++++ .../org/apache/spark/sql/JavaDataFrameSuite.java | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f099940800cc0..1bd4e26fb3162 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -476,6 +476,20 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 1.6.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + DataFrame(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + /** * Applies a schema to an RDD of Java Beans. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 5f9abd4999ce0..250ac2e1092d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -37,6 +37,7 @@ import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -181,6 +182,15 @@ public void testCreateDataFrameFromJavaBeans() { } } + @Test + public void testCreateDataFromFromList() { + StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); + List rows = Arrays.asList(RowFactory.create(0)); + DataFrame df = context.createDataFrame(rows, schema); + Row[] result = df.collect(); + Assert.assertEquals(1, result.length); + } + private static final Comparator crosstabRowComparator = new Comparator() { @Override public int compare(Row row1, Row row2) { From 7c4f852bfc39537840f56cd8121457a0dc1ad7c1 Mon Sep 17 00:00:00 2001 From: noelsmith Date: Mon, 21 Sep 2015 14:24:19 -0700 Subject: [PATCH 367/802] [DOC] [PYSPARK] [MLLIB] Added newlines to docstrings to fix parameter formatting Added newlines before `:param ...:` and `:return:` markup. Without these, parameter lists aren't formatted correctly in the API docs. I.e: ![screen shot 2015-09-21 at 21 49 26](https://cloud.githubusercontent.com/assets/11915197/10004686/de3c41d4-60aa-11e5-9c50-a46dcb51243f.png) .. looks like this once newline is added: ![screen shot 2015-09-21 at 21 50 14](https://cloud.githubusercontent.com/assets/11915197/10004706/f86bfb08-60aa-11e5-8524-ae4436713502.png) Author: noelsmith Closes #8851 from noel-smith/docstring-missing-newline-fix. --- python/pyspark/ml/param/__init__.py | 4 ++++ python/pyspark/ml/pipeline.py | 1 + python/pyspark/ml/tuning.py | 2 ++ python/pyspark/ml/wrapper.py | 2 ++ python/pyspark/mllib/evaluation.py | 2 +- python/pyspark/mllib/linalg/__init__.py | 1 + python/pyspark/streaming/context.py | 2 ++ python/pyspark/streaming/mqtt.py | 1 + 8 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index eeeac49b21980..2e0c63cb47b17 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -164,6 +164,7 @@ def extractParamMap(self, extra=None): a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < user-supplied values < extra. + :param extra: extra param values :return: merged param map """ @@ -182,6 +183,7 @@ def copy(self, extra=None): embedded and extra parameters over and returns the copy. Subclasses should override this method if the default approach is not sufficient. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -201,6 +203,7 @@ def _shouldOwn(self, param): def _resolveParam(self, param): """ Resolves a param and validates the ownership. + :param param: param name or the param instance, which must belong to this Params instance :return: resolved param instance @@ -243,6 +246,7 @@ def _copyValues(self, to, extra=None): """ Copies param values from this instance to another instance for params shared by them. 
+ :param to: the target instance :param extra: extra params to be copied :return: the target instance with param values copied diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 13cf2b0f7bbd9..312a8502b3a2c 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -154,6 +154,7 @@ def __init__(self, stages=None): def setStages(self, value): """ Set pipeline stages. + :param value: a list of transformers or estimators :return: the pipeline instance """ diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ab5621f45c72c..705ee53685752 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -254,6 +254,7 @@ def copy(self, extra=None): Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ @@ -290,6 +291,7 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8218c7c5f801c..4bcb4aaec89de 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -119,6 +119,7 @@ def _create_model(self, java_model): def _fit_java(self, dataset): """ Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` :param params: additional params (overwriting embedded values) @@ -173,6 +174,7 @@ def copy(self, extra=None): extra params. This implementation first calls Params.copy and then make a copy of the companion Java model with extra params. So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4398ca86f2ec2..a90e5c50e54b9 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -147,7 +147,7 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. - :param predictionAndLabels an RDD of (prediction, label) pairs. + :param predictionAndLabels: an RDD of (prediction, label) pairs. >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index f929e3e96fbe2..ea42127f1651f 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -240,6 +240,7 @@ class Vector(object): def toArray(self): """ Convert the vector into an numpy.ndarray + :return: numpy.ndarray """ raise NotImplementedError diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4069d7a149986..a8c9ffc235b9e 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -240,6 +240,7 @@ def start(self): def awaitTermination(self, timeout=None): """ Wait for the execution to stop. 
+ @param timeout: time to wait in seconds """ if timeout is None: @@ -252,6 +253,7 @@ def awaitTerminationOrTimeout(self, timeout): Wait for the execution to stop. Return `true` if it's stopped; or throw the reported error during the execution; or `false` if the waiting time elapsed before returning from the method. + @param timeout: time to wait in seconds """ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000)) diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index f06598971c548..fa83006c36db6 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -31,6 +31,7 @@ def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): """ Create an input stream that pulls messages from a Mqtt Broker. + :param ssc: StreamingContext object :param brokerUrl: Url of remote mqtt publisher :param topic: topic name to subscribe to From 72869883f12b6e0a4e5aad79c0ac2cfdb4d83f09 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 21 Sep 2015 16:47:52 -0700 Subject: [PATCH 368/802] [SPARK-10649] [STREAMING] Prevent inheriting job group and irrelevant job description in streaming jobs The job group, and job descriptions information is passed through thread local properties, and get inherited by child threads. In case of spark streaming, the streaming jobs inherit these properties from the thread that called streamingContext.start(). This may not make sense. 1. Job group: This is mainly used for cancelling a group of jobs together. It does not make sense to cancel streaming jobs like this, as the effect will be unpredictable. And its not a valid usecase any way, to cancel a streaming context, call streamingContext.stop() 2. Job description: This is used to pass on nice text descriptions for jobs to show up in the UI. The job description of the thread that calls streamingContext.start() is not useful for all the streaming jobs, as it does not make sense for all of the streaming jobs to have the same description, and the description may or may not be related to streaming. The solution in this PR is meant for the Spark master branch, where local properties are inherited by cloning the properties. The job group and job description in the thread that starts the streaming scheduler are explicitly removed, so that all the subsequent child threads does not inherit them. Also, the starting is done in a new child thread, so that setting the job group and description for streaming, does not change those properties in the thread that called streamingContext.start(). Author: Tathagata Das Closes #8781 from tdas/SPARK-10649. 
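A minimal sketch of the scenario this change addresses (setup values are illustrative; the new StreamingContextSuite test in the diff below checks the same behavior):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf().setMaster("local[2]").setAppName("job-group-demo")
    val ssc = new StreamingContext(conf, Seconds(1))
    // ... register input streams and output operations here (required before start()) ...

    // The thread that calls start() has its own job group and description set for
    // unrelated batch work.
    ssc.sparkContext.setJobGroup("non-streaming", "non-streaming work", interruptOnCancel = true)
    ssc.start()
    // Before this patch, jobs generated by the streaming scheduler inherited the
    // "non-streaming" group and description, so cancelJobGroup("non-streaming") could also
    // kill streaming batches. With the scheduler now started on a separate "streaming-start"
    // thread that clears those thread-local properties, streaming jobs carry neither of them.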
--- .../org/apache/spark/util/ThreadUtils.scala | 59 +++++++++++++++++++ .../apache/spark/util/ThreadUtilsSuite.scala | 24 +++++++- .../spark/streaming/StreamingContext.scala | 15 ++++- .../streaming/StreamingContextSuite.scala | 32 ++++++++++ 4 files changed, 126 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3d..22e291a2b48d6 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -21,6 +21,7 @@ package org.apache.spark.util import java.util.concurrent._ import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} @@ -86,4 +87,62 @@ private[spark] object ThreadUtils { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() Executors.newSingleThreadScheduledExecutor(threadFactory) } + + /** + * Run a piece of code in a new thread and return the result. Exception in the new thread is + * thrown in the caller thread with an adjusted stack trace that removes references to this + * method for clarity. The exception stack traces will be like the following + * + * SomeException: exception-message + * at CallerClass.body-method (sourcefile.scala) + * at ... run in separate thread using org.apache.spark.util.ThreadUtils ... () + * at CallerClass.caller-method (sourcefile.scala) + * ... + */ + def runInNewThread[T]( + threadName: String, + isDaemon: Boolean = true)(body: => T): T = { + @volatile var exception: Option[Throwable] = None + @volatile var result: T = null.asInstanceOf[T] + + val thread = new Thread(threadName) { + override def run(): Unit = { + try { + result = body + } catch { + case NonFatal(e) => + exception = Some(e) + } + } + } + thread.setDaemon(isDaemon) + thread.start() + thread.join() + + exception match { + case Some(realException) => + // Remove the part of the stack that shows method calls into this helper method + // This means drop everything from the top until the stack element + // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`). + val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1) + + // Remove the part of the new thread stack that shows methods call from this helper method + val extraStackTrace = realException.getStackTrace.takeWhile( + ! _.getClassName.contains(this.getClass.getSimpleName)) + + // Combine the two stack traces, with a place holder just specifying that there + // was a helper method used, without any further details of the helper + val placeHolderStackElem = new StackTraceElement( + s"... 
run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..", + " ", "", -1) + val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace + + // Update the stack trace and rethrow the exception in the caller thread + realException.setStackTrace(finalStackTrace) + throw realException + case None => + result + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 8c51e6b14b7fc..620e4debf4e08 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, Future} +import scala.util.Random import org.apache.spark.SparkFunSuite @@ -66,4 +67,25 @@ class ThreadUtilsSuite extends SparkFunSuite { val futureThreadName = Await.result(f, 10.seconds) assert(futureThreadName === callerThreadName) } + + test("runInNewThread") { + import ThreadUtils._ + assert(runInNewThread("thread-name") { Thread.currentThread().getName } === "thread-name") + assert(runInNewThread("thread-name") { Thread.currentThread().isDaemon } === true) + assert( + runInNewThread("thread-name", isDaemon = false) { Thread.currentThread().isDaemon } === false + ) + val uniqueExceptionMessage = "test" + Random.nextInt() + val exception = intercept[IllegalArgumentException] { + runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) } + } + assert(exception.asInstanceOf[IllegalArgumentException].getMessage === uniqueExceptionMessage) + assert(exception.getStackTrace.mkString("\n").contains( + "... run in separate thread using org.apache.spark.util.ThreadUtils ...") === true, + "stack trace does not contain expected place holder" + ) + assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false, + "stack trace contains unexpected references to ThreadUtils" + ) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b496d1f341a0b..6720ba4f72cf3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, Utils} +import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils} /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -588,12 +588,20 @@ class StreamingContext private[streaming] ( state match { case INITIALIZED => startSite.set(DStream.getCreationSite()) - sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() try { validate() - scheduler.start() + + // Start the streaming scheduler in a new thread, so that thread local properties + // like call sites and job groups can be reset without affecting those of the + // current thread. + ThreadUtils.runInNewThread("streaming-start") { + sparkContext.setCallSite(startSite.get) + sparkContext.clearJobGroup() + sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + scheduler.start() + } state = StreamingContextState.ACTIVE } catch { case NonFatal(e) => @@ -618,6 +626,7 @@ class StreamingContext private[streaming] ( } } + /** * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index d26894e88fc26..3b9d0d15ea04c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -180,6 +180,38 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } + test("start should set job group and description of streaming jobs correctly") { + ssc = new StreamingContext(conf, batchDuration) + ssc.sc.setJobGroup("non-streaming", "non-streaming", true) + val sc = ssc.sc + + @volatile var jobGroupFound: String = "" + @volatile var jobDescFound: String = "" + @volatile var jobInterruptFound: String = "" + @volatile var allFound: Boolean = false + + addInputStream(ssc).foreachRDD { rdd => + jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) + jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + allFound = true + } + ssc.start() + + eventually(timeout(10 seconds), interval(10 milliseconds)) { + assert(allFound === true) + } + + // Verify streaming jobs have expected thread-local properties + assert(jobGroupFound === null) + assert(jobDescFound === null) + assert(jobInterruptFound === "false") + + // Verify current thread's thread-local properties have not changed + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + } test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) From 0494c80ef54f6f3a8c6f2d92abfe1a77a91df8b0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 21 Sep 2015 18:06:45 -0700 Subject: [PATCH 369/802] [SPARK-10495] [SQL] Read date values in JSON data stored by Spark 1.5.0. https://issues.apache.org/jira/browse/SPARK-10681 Author: Yin Huai Closes #8806 from yhuai/SPARK-10495. 
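For context, here is a small editor-added PySpark sketch of the case this patch covers (the SQLContext and the inline records are assumptions for illustration): with an explicit DateType column in the schema, the JSON reader now accepts both the usual "yyyy-mm-dd" form and the days-since-epoch string that Spark 1.5.0 wrote out.

```python
from pyspark.sql.types import DateType, StringType, StructField, StructType

# `sqlContext` is assumed to be an existing SQLContext.
schema = StructType([
    StructField("name", StringType(), True),
    StructField("d", DateType(), True),
])

# One record uses the "yyyy-mm-dd" form, the other the days-since-epoch string
# written by Spark 1.5.0 ("16436" corresponds to 2015-01-01).
records = sqlContext.sparkContext.parallelize([
    '{"name": "a", "d": "2015-01-01"}',
    '{"name": "b", "d": "16436"}',
])

df = sqlContext.read.schema(schema).json(records)
df.show()   # after this change both rows come back with d = 2015-01-01
```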
--- .../datasources/json/JacksonGenerator.scala | 36 ++++++ .../datasources/json/JacksonParser.scala | 15 ++- .../datasources/json/JsonSuite.scala | 103 +++++++++++++++++- 3 files changed, 152 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index f65c7bbd6e29d..23bada1ddd92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -73,6 +73,38 @@ private[sql] object JacksonGenerator { valWriter(field.dataType, v) } gen.writeEndObject() + + // For UDT, udt.serialize will produce SQL types. So, we need the following three cases. + case (ArrayType(ty, _), v: ArrayData) => + gen.writeStartArray() + v.foreach(ty, (_, value) => valWriter(ty, value)) + gen.writeEndArray() + + case (MapType(kt, vt, _), v: MapData) => + gen.writeStartObject() + v.foreach(kt, vt, { (k, v) => + gen.writeFieldName(k.toString) + valWriter(vt, v) + }) + gen.writeEndObject() + + case (StructType(ty), v: InternalRow) => + gen.writeStartObject() + var i = 0 + while (i < ty.length) { + val field = ty(i) + val value = v.get(i, field.dataType) + if (value != null) { + gen.writeFieldName(field.name) + valWriter(field.dataType, value) + } + i += 1 + } + gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) @@ -133,6 +165,10 @@ private[sql] object JacksonGenerator { i += 1 } gen.writeEndObject() + + case (dt, v) => + sys.error( + s"Failed to convert value $v (class of ${v.getClass}}) with the type of $dt to JSON.") } valWriter(rowSchema, row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index ff4d8c04e8eaf..c51140749c8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -62,10 +62,23 @@ private[sql] object JacksonParser { // guard the non string type null + case (VALUE_STRING, BinaryType) => + parser.getBinaryValue + case (VALUE_STRING, DateType) => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + val stringValue = parser.getText + if (stringValue.contains("-")) { + // The format of this string will probably be "yyyy-mm-dd". + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) + } else { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } case (VALUE_STRING, TimestampType) => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. 
DateTimeUtils.stringToTime(parser.getText).getTime * 1000L case (VALUE_NUMBER_INT, TimestampType) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 6a18cc6d27138..b614e6c4148fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.apache.spark.rdd.RDD import org.scalactic.Tolerance._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType @@ -1159,4 +1159,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9)) }) } + + test("backward compatibility") { + // This test we make sure our JSON support can read JSON data generated by previous version + // of Spark generated through toJSON method and JSON data source. + // The data is generated by the following program. + // Here are a few notes: + // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13) + // in the JSON object. + // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to + // JSON objects generated by those Spark versions (col17). + // - If the type is NullType, we do not write data out. + + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + + val dataTypes = + Seq( + StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct, + new MyDenseVectorUDT()) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, nullable = true) + } + val schema = StructType(fields) + + val constantValues = + Seq( + "a string in binary".getBytes("UTF-8"), + null, + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75, + new java.math.BigDecimal(s"1234.23456"), + new java.math.BigDecimal(s"1.23456"), + java.sql.Date.valueOf("2015-01-01"), + java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"), + Seq(2, 3, 4), + Map("a string" -> 2000L), + Row(4.75.toFloat, Seq(false, true)), + new MyDenseVector(Array(0.25, 2.25, 4.25))) + val data = + Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + + // Data generated by previous versions. 
+ // scalastyle:off + val existingJSONData = + """{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: + """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil + // scalastyle:on + + // Generate data for the current version. + val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + withTempPath { path => + df.write.format("json").mode("overwrite").save(path.getCanonicalPath) + + // df.toJSON will convert internal rows to external rows first and then generate + // JSON objects. While, df.write.format("json") will write internal rows directly. + val allJSON = + existingJSONData ++ + df.toJSON.collect() ++ + sparkContext.textFile(path.getCanonicalPath).collect() + + Utils.deleteRecursively(path) + sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath) + + // Read data back with the schema specified. 
+ val col0Values = + Seq( + "Spark 1.2.2", + "Spark 1.3.1", + "Spark 1.3.1", + "Spark 1.4.1", + "Spark 1.4.1", + "Spark 1.5.0", + "Spark 1.5.0", + "Spark " + sqlContext.sparkContext.version, + "Spark " + sqlContext.sparkContext.version) + val expectedResult = col0Values.map { v => + Row.fromSeq(Seq(v) ++ constantValues) + } + checkAnswer( + sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + expectedResult + ) + } + } } From c986e933a900602af47966bd41edb2116c421a39 Mon Sep 17 00:00:00 2001 From: Hossein Date: Mon, 21 Sep 2015 21:09:59 -0700 Subject: [PATCH 370/802] [SPARK-10711] [SPARKR] Do not assume spark.submit.deployMode is always set In ```RUtils.sparkRPackagePath()``` we 1. Call ``` sys.props("spark.submit.deployMode")``` which returns null if ```spark.submit.deployMode``` is not suet 2. Call ``` sparkConf.get("spark.submit.deployMode")``` which throws ```NoSuchElementException``` if ```spark.submit.deployMode``` is not set. This patch simply passes a default value ("cluster") for ```spark.submit.deployMode```. cc rxin Author: Hossein Closes #8832 from falaki/SPARK-10711. --- core/src/main/scala/org/apache/spark/api/r/RUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 9e807cc52f18c..fd5646b5b6372 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -44,7 +44,7 @@ private[spark] object RUtils { (sys.props("spark.master"), sys.props("spark.submit.deployMode")) } else { val sparkConf = SparkEnv.get.conf - (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode")) + (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode", "client")) } val isYarnCluster = master != null && master.contains("yarn") && deployMode == "cluster" From 1cd67415728e660a90e4dbe136272b5d6b8f1142 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Sep 2015 23:21:24 -0700 Subject: [PATCH 371/802] [SPARK-9821] [PYSPARK] pyspark-reduceByKey-should-take-a-custom-partitioner from the issue: In Scala, I can supply a custom partitioner to reduceByKey (and other aggregation/repartitioning methods like aggregateByKey and combinedByKey), but as far as I can tell from the Pyspark API, there's no way to do the same in Python. Here's an example of my code in Scala: weblogs.map(s => (getFileType(s), 1)).reduceByKey(new FileTypePartitioner(),_+_) But I can't figure out how to do the same in Python. The closest I can get is to call repartition before reduceByKey like so: weblogs.map(lambda s: (getFileType(s), 1)).partitionBy(3,hash_filetype).reduceByKey(lambda v1,v2: v1+v2).collect() But that defeats the purpose, because I'm shuffling twice instead of once, so my performance is worse instead of better. Author: Holden Karau Closes #8569 from holdenk/SPARK-9821-pyspark-reduceByKey-should-take-a-custom-partitioner. 
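With this change the single-shuffle version of the example above can be written directly in Python. A short editor-added sketch of the new keyword arguments follows; it assumes an existing SparkContext `sc`, and the partitioner is a stand-in for the custom one in the description.

```python
from operator import add

# Any callable from key to int works; the result is taken modulo the number of
# partitions, just as with partitionBy().
def first_char_partitioner(key):
    return ord(key[0])

pairs = sc.parallelize([("html", 1), ("css", 1), ("html", 1), ("js", 1)])
counts = pairs.reduceByKey(add, numPartitions=3,
                           partitionFunc=first_char_partitioner)
print(sorted(counts.collect()))   # [('css', 1), ('html', 2), ('js', 1)]
```

The same partitionFunc argument is added to groupBy, groupByKey, combineByKey, aggregateByKey and foldByKey in the diff below.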
--- python/pyspark/rdd.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 73d7d9a5692a9..56e892243c79c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -686,7 +686,7 @@ def cartesian(self, other): other._jrdd_deserializer) return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) - def groupBy(self, f, numPartitions=None): + def groupBy(self, f, numPartitions=None, partitionFunc=portable_hash): """ Return an RDD of grouped items. @@ -695,7 +695,7 @@ def groupBy(self, f, numPartitions=None): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions, partitionFunc) @ignore_unicode_prefix def pipe(self, command, env=None, checkCode=False): @@ -1539,22 +1539,23 @@ def values(self): """ return self.map(lambda x: x[1]) - def reduceByKey(self, func, numPartitions=None): + def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numPartitions} partitions, or + Output will be partitioned with C{numPartitions} partitions, or the default parallelism level if C{numPartitions} is not specified. + Default partitioner is hash-partition. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numPartitions) + return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): """ @@ -1739,7 +1740,7 @@ def add_shuffle_key(split, iterator): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numPartitions=None): + numPartitions=None, partitionFunc=portable_hash): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -1777,7 +1778,7 @@ def combineLocally(iterator): return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) @@ -1786,7 +1787,8 @@ def _mergeCombiners(iterator): return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) - def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): + def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None, + partitionFunc=portable_hash): """ Aggregate the values of each key, using given combine functions and a neutral "zero value". 
This function can return a different result type, U, than the type @@ -1800,9 +1802,9 @@ def createZero(): return copy.deepcopy(zeroValue) return self.combineByKey( - lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions, partitionFunc) - def foldByKey(self, zeroValue, func, numPartitions=None): + def foldByKey(self, zeroValue, func, numPartitions=None, partitionFunc=portable_hash): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" which may be added to the result an @@ -1817,13 +1819,14 @@ def foldByKey(self, zeroValue, func, numPartitions=None): def createZero(): return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) + return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions, + partitionFunc) def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) # TODO: support variant with custom partitioner - def groupByKey(self, numPartitions=None): + def groupByKey(self, numPartitions=None, partitionFunc=portable_hash): """ Group the values for each key in the RDD into a single sequence. Hash-partitions the resulting RDD with numPartitions partitions. @@ -1859,7 +1862,7 @@ def combine(iterator): return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) - shuffled = locally_combined.partitionBy(numPartitions) + shuffled = locally_combined.partitionBy(numPartitions, partitionFunc) def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer) From bf20d6c9f9e478a5de24b45bbafd4dd89666c4cf Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 21 Sep 2015 23:29:59 -0700 Subject: [PATCH 372/802] [SPARK-10716] [BUILD] spark-1.5.0-bin-hadoop2.6.tgz file doesn't uncompress on OS X due to hidden file Remove ._SUCCESS.crc hidden file that may cause problems in distribution tar archive, and is not used Author: Sean Owen Closes #8846 from srowen/SPARK-10716. --- .../test_support/sql/orc_partitioned/._SUCCESS.crc | Bin 8 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc deleted file mode 100644 index 3b7b044936a890cd8d651d349a752d819d71d22c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8 PcmYc;N@ieSU}69O2$TUk From 0180b849dbaf191826231eda7dfaaf146a19602b Mon Sep 17 00:00:00 2001 From: Jian Feng Date: Mon, 21 Sep 2015 23:36:41 -0700 Subject: [PATCH 373/802] [SPARK-10577] [PYSPARK] DataFrame hint for broadcast join https://issues.apache.org/jira/browse/SPARK-10577 Author: Jian Feng Closes #8801 from Jianfeng-chs/master. 
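A short usage sketch of the new Python hint (editor-added; it assumes an existing SQLContext named `sqlContext` and mirrors the test included below):

```python
from pyspark.sql.functions import broadcast

large = sqlContext.createDataFrame([(1, "1"), (2, "2"), (3, "3")], ("key", "value"))
small = sqlContext.createDataFrame([(1, "one"), (2, "two")], ("key", "name"))

# Marking the smaller side asks the planner for a broadcast hash join instead
# of shuffling both sides.
joined = large.join(broadcast(small), "key")
joined.explain()   # the physical plan should mention BroadcastHashJoin
```

As the new test below shows, the hint only takes effect for joins with join keys; a join without keys does not become a broadcast join.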
--- python/pyspark/sql/functions.py | 9 +++++++++ python/pyspark/sql/tests.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 26b8662718a60..fa04f4cd83b6f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -29,6 +29,7 @@ from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq +from pyspark.sql.dataframe import DataFrame def _create_function(name, doc=""): @@ -189,6 +190,14 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.6) +def broadcast(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.broadcast(df._jdf), df.sql_ctx) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3e680f1030a71..645133b2b2d84 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1075,6 +1075,24 @@ def foo(): self.assertRaises(TypeError, foo) + # add test for SPARK-10577 (test broadcast join hint) + def test_functions_broadcast(self): + from pyspark.sql.functions import broadcast + + df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + + # equijoin - should be converted into broadcast join + plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan1.toString().count("BroadcastHashJoin")) + + # no join key -- should not be a broadcast join + plan2 = df1.join(broadcast(df2))._jdf.queryExecution().executedPlan() + self.assertEqual(0, plan2.toString().count("BroadcastHashJoin")) + + # planner should not crash without a join + broadcast(df1)._jdf.queryExecution().executedPlan() + class HiveContextSQLTests(ReusedPySparkTestCase): From 781b21ba2a873ed29394c8dbc74fc700e3e0d17e Mon Sep 17 00:00:00 2001 From: Ewan Leith Date: Mon, 21 Sep 2015 23:43:20 -0700 Subject: [PATCH 374/802] [SPARK-10419] [SQL] Adding SQLServer support for datetimeoffset types to JdbcDialects Reading from Microsoft SQL Server over jdbc fails when the table contains datetimeoffset types. This patch registers a SQLServer JDBC Dialect that maps datetimeoffset to a String, as Microsoft suggest. Author: Ewan Leith Closes #8575 from realitymine-coordinator/sqlserver. --- .../apache/spark/sql/jdbc/JdbcDialects.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 1 + 2 files changed, 19 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 68ebaaca6c53d..c70fea1c3f50e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -137,6 +137,8 @@ object JdbcDialects { registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) + registerDialect(MsSqlServerDialect) + /** * Fetch the JdbcDialect class corresponding to a given database url. @@ -260,3 +262,19 @@ case object DB2Dialect extends JdbcDialect { case _ => None } } + +/** + * :: DeveloperApi :: + * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. 
+ */ +@DeveloperApi +case object MsSqlServerDialect extends JdbcDialect { + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Some(StringType) + } else None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5ab9381de4d66..c4b039a9c5359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -408,6 +408,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) assert(JdbcDialects.get("jdbc:db2://127.0.0.1/db") == DB2Dialect) + assert(JdbcDialects.get("jdbc:sqlserver://127.0.0.1/db") == MsSqlServerDialect) assert(JdbcDialects.get("test.invalid") == NoopDialect) } From 1fcefef06950e2f03477282368ca835bbf40ff24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2015 23:46:00 -0700 Subject: [PATCH 375/802] [SPARK-10446][SQL] Support to specify join type when calling join with usingColumns JIRA: https://issues.apache.org/jira/browse/SPARK-10446 Currently the method `join(right: DataFrame, usingColumns: Seq[String])` only supports inner join. It is more convenient to have it support other join types. Author: Liang-Chi Hsieh Closes #8600 from viirya/usingcolumns_df. --- python/pyspark/sql/dataframe.py | 6 ++++- .../org/apache/spark/sql/DataFrame.scala | 22 ++++++++++++++++++- .../apache/spark/sql/DataFrameJoinSuite.scala | 13 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index fb995fa3a76b5..80f8d8a0eb37d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -567,7 +567,11 @@ def join(self, other, on=None, how=None): if on is None or len(on) == 0: jdf = self._jdf.join(other._jdf) elif isinstance(on[0], basestring): - jdf = self._jdf.join(other._jdf, self._jseq(on)) + if how is None: + jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + else: + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, self._jseq(on), how) else: assert isinstance(on[0], Column), "on should be Column or list of Column" if len(on) > 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8f737c2023931..ba94d77b2e60e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -484,6 +484,26 @@ class DataFrame private[sql]( * @since 1.4.0 */ def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + join(right, usingColumns, "inner") + } + + /** + * Equi-join with another [[DataFrame]] using the given columns. + * + * Different from other join functions, the join columns will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. 
+ * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumns Names of the columns to join on. This columns must exist on both sides. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + * @group dfops + * @since 1.6.0 + */ + def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( @@ -502,7 +522,7 @@ class DataFrame private[sql]( Join( joined.left, joined.right, - joinType = Inner, + joinType = JoinType(joinType), condition) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index e2716d7841d85..56ad71ea4f487 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -42,6 +42,19 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil) } + test("join - join using multiple columns and specifying join type") { + val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str") + val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str") + + checkAnswer( + df.join(df2, Seq("int", "str"), "left"), + Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + + checkAnswer( + df.join(df2, Seq("int", "str"), "right"), + Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + } + test("join - join using self join") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") From f24316e6d928c263cbf3872edd97982059c3db22 Mon Sep 17 00:00:00 2001 From: Madhusudanan Kandasamy Date: Tue, 22 Sep 2015 00:03:48 -0700 Subject: [PATCH 376/802] [SPARK-10458] [SPARK CORE] Added isStopped() method in SparkContext Added isStopped() method in SparkContext Author: Madhusudanan Kandasamy Closes #8749 from kmadhugit/SPARK-10458. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ebd8e946ee7a2..967fec9f42bcf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** + * @return true if context is stopped or in the midst of stopping. 
+ */ + def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus From fd61b004877ba4d51c95cd0e08f53bffdf106395 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 00:05:30 -0700 Subject: [PATCH 377/802] [Minor] style fix for previous commit f24316e --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 967fec9f42bcf..bf3aeb488d597 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -265,6 +265,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) + /** * @return true if context is stopped or in the midst of stopping. */ From 4da32bc0e747fefe847bffe493785d4d16069c04 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 00:07:30 -0700 Subject: [PATCH 378/802] [SPARK-8567] [SQL] Increase the timeout of o.a.s.sql.hive.HiveSparkSubmitSuite to 5 minutes. https://issues.apache.org/jira/browse/SPARK-8567 Looks like "SPARK-8368: includes jars passed in through --jars" is pretty flaky now. Based on some history runs, the time spent on a successful run may be from 1.5 minutes to almost 3 minutes. Let's try to increase the timeout and see if we can fix this test. https://amplab.cs.berkeley.edu/jenkins/job/Spark-1.5-SBT/AMPLAB_JENKINS_BUILD_PROFILE=hadoop2.0,label=spark-test/385/testReport/junit/org.apache.spark.sql.hive/HiveSparkSubmitSuite/SPARK_8368__includes_jars_passed_in_through___jars/history/?start=25 Author: Yin Huai Closes #8850 from yhuai/SPARK-8567-anotherTry. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 97df249bdb6d6..5f1660b62d418 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -139,7 +139,7 @@ class HiveSparkSubmitSuite new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - val exitCode = failAfter(180.seconds) { process.waitFor() } + val exitCode = failAfter(300.seconds) { process.waitFor() } if (exitCode != 0) { // include logs in output. Note that logging is async and may not have completed // at the time this exception is raised From f3b727c801408b1cd50e5d9463f2fe0fce654a16 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 00:09:29 -0700 Subject: [PATCH 379/802] [SQL] [MINOR] map -> foreach. DataFrame.explain should use foreach to print the explain content. Author: Reynold Xin Closes #8862 from rxin/map-foreach. 
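The same style point in an editor-added Python aside: map used only for its side effect either builds a throwaway collection (Python 2, and the Scala map being replaced here) or, in Python 3, does nothing until the lazy iterator is consumed. An explicit loop, the analogue of foreach, is the right tool for printing.

```python
from __future__ import print_function

rows = ["== Physical Plan ==", "Scan ..."]

map(print, rows)   # Python 2: prints, but also builds a list of Nones
                   # Python 3: lazy, prints nothing at all

for r in rows:     # foreach-style: eager, side effect only
    print(r)
```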
--- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ba94d77b2e60e..a11140b717360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -320,9 +320,8 @@ class DataFrame private[sql]( * @since 1.3.0 */ def explain(extended: Boolean): Unit = { - ExplainCommand( - queryExecution.logical, - extended = extended).queryExecution.executedPlan.executeCollect().map { + val explain = ExplainCommand(queryExecution.logical, extended = extended) + explain.queryExecution.executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println From 0bd0e5bed2176b119b3ada590993e153757ea09b Mon Sep 17 00:00:00 2001 From: Akash Mishra Date: Tue, 22 Sep 2015 00:14:27 -0700 Subject: [PATCH 380/802] =?UTF-8?q?[SPARK-10695]=20[DOCUMENTATION]=20[MESO?= =?UTF-8?q?S]=20Fixing=20incorrect=20value=20informati=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …on for spark.mesos.constraints parameter. Author: Akash Mishra Closes #8816 from SleepyThread/constraint-fix. --- docs/running-on-mesos.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 460a66f37dd64..ec5a44d79212b 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -189,10 +189,10 @@ using `conf.set("spark.cores.max", "10")` (for example). You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon=true;us-east-1=false") +conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon=true;us-east-1=false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support From 7278f792a73bbcf8d68f38dc2d87cf722693c4cf Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Tue, 22 Sep 2015 11:03:21 +0100 Subject: [PATCH 381/802] [SPARK-10718] [BUILD] Update License on conf files and corresponding excludes file update Update License on conf files and corresponding excludes file update Author: Rekha Joshi Author: Joshi Closes #8842 from rekhajoshm/SPARK-10718. 
--- .rat-excludes | 12 ------------ conf/docker.properties.template | 17 +++++++++++++++++ conf/fairscheduler.xml.template | 18 ++++++++++++++++++ conf/log4j.properties.template | 17 +++++++++++++++++ conf/metrics.properties.template | 17 +++++++++++++++++ conf/slaves.template | 17 +++++++++++++++++ conf/spark-defaults.conf.template | 17 +++++++++++++++++ conf/spark-env.sh.template | 17 +++++++++++++++++ .../spark/log4j-defaults-repl.properties | 17 +++++++++++++++++ .../org/apache/spark/log4j-defaults.properties | 17 +++++++++++++++++ 10 files changed, 154 insertions(+), 12 deletions(-) diff --git a/.rat-excludes b/.rat-excludes index 9165872b9fb27..08fba6d351d6a 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,20 +15,8 @@ TAGS RELEASE control docs -docker.properties.template -fairscheduler.xml.template -spark-defaults.conf.template -log4j.properties -log4j.properties.template -metrics.properties -metrics.properties.template slaves -slaves.template -spark-env.sh spark-env.cmd -spark-env.sh.template -log4j-defaults.properties -log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/conf/docker.properties.template b/conf/docker.properties.template index 26e3bfd9c5b9b..55cb094b4af46 100644 --- a/conf/docker.properties.template +++ b/conf/docker.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + spark.mesos.executor.docker.image: spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro spark.mesos.executor.home: /opt/spark diff --git a/conf/fairscheduler.xml.template b/conf/fairscheduler.xml.template index acf59e2a35986..385b2e772d2c8 100644 --- a/conf/fairscheduler.xml.template +++ b/conf/fairscheduler.xml.template @@ -1,4 +1,22 @@ + + + FAIR diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 74c5cea94403a..f3046be54d7c6 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 7f17bc7eea4f5..d6962e0da2f30 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # syntax: [instance].sink|source.[name].[options]=[value] # This file configures Spark's internal metrics system. The metrics system is diff --git a/conf/slaves.template b/conf/slaves.template index da0a01343d20a..be42a638230b7 100644 --- a/conf/slaves.template +++ b/conf/slaves.template @@ -1,2 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # A Spark Worker will be started on each of the machines listed below. localhost \ No newline at end of file diff --git a/conf/spark-defaults.conf.template b/conf/spark-defaults.conf.template index a48dcc70e1363..19cba6e71ed19 100644 --- a/conf/spark-defaults.conf.template +++ b/conf/spark-defaults.conf.template @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Default system properties included when running spark-submit. # This is useful for setting default environmental settings. 
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index c05fe381a36a7..990ded420be72 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -1,5 +1,22 @@ #!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index 689afea64f8db..c85abc35b93bf 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=WARN, console log4j.appender.console=org.apache.log4j.ConsoleAppender diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 27006e45e932b..d44cc85dcbd82 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + # Set everything to be logged to the console log4j.rootCategory=INFO, console log4j.appender.console=org.apache.log4j.ConsoleAppender From 870b8a2edd44c9274c43ca0db4ef5b0998e16fd8 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Tue, 22 Sep 2015 11:05:24 +0100 Subject: [PATCH 382/802] [SPARK-10706] [MLLIB] Add java wrapper for random vector rdd Add java wrapper for random vector rdd holdenk srowen Author: Meihua Wu Closes #8841 from rotationsymmetry/SPARK-10706. --- .../spark/mllib/random/RandomRDDs.scala | 42 +++++++++++++++++++ .../mllib/random/JavaRandomRDDsSuite.java | 17 ++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index f8ff26b5795be..41d7c4d355f61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -855,6 +855,48 @@ object RandomRDDs { sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) } + /** + * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. + */ + @DeveloperApi + @Since("1.6.0") + def randomJavaVectorRDD( + jsc: JavaSparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + randomVectorRDD(jsc.sc, generator, numRows, numCols).toJavaRDD() + } + /** * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. 
*/ diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index fce5f6712f462..5728df5aeebdc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -246,6 +246,23 @@ public void testArbitrary() { Assert.assertEquals(2, rdd.first().length()); } } + + @Test + @SuppressWarnings("unchecked") + public void testRandomVectorRDD() { + UniformGenerator generator = new UniformGenerator(); + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); + for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } } // This is just a test generator, it always returns a string of 42 From f4a3c4e34ce93bcaf29c0a35573932880a8b792b Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 22 Sep 2015 10:19:08 -0700 Subject: [PATCH 383/802] [SPARK-9962] [ML] Decision Tree training: prevNodeIdsForInstances.unpersist() at end of training NodeIdCache: prevNodeIdsForInstances.unpersist() needs to be called at end of training. Author: Holden Karau Closes #8541 from holdenk/SPARK-9962-decission-tree-training-prevNodeIdsForiNstances-unpersist-at-end-of-training. --- .../scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 8 ++++---- .../org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 488e8e4fb5dcd..c5ad8df73fac9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -164,10 +164,10 @@ private[spark] class NodeIdCache( } } } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 8f9eb24b57b55..0abed5411143d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -166,6 +166,10 @@ private[spark] class NodeIdCache( fs.delete(new Path(old.getCheckpointFile.get), true) } } + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } } } From 7104ee0e5dc1290b8b845a0a8ddcdb1875cfd060 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 22 Sep 2015 11:00:33 -0700 Subject: [PATCH 384/802] [SPARK-10750] [ML] ML Param validate should print better error information Currently when you set illegal value for params of array type (such as IntArrayParam, DoubleArrayParam, StringArrayParam), it will throw IllegalArgumentException but with incomprehensible error information. 
Take ```VectorSlicer.setNames``` as an example: ```scala val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") // The value of setNames must be contain distinct elements, so the next line will throw exception. vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4", "f1")) ``` It will throw IllegalArgumentException as: ``` vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. java.lang.IllegalArgumentException: vectorSlicer_b3b4d1a10f43 parameter names given invalid value [Ljava.lang.String;798256c5. ``` We should distinguish the value of array type from primitive type at Param.validate(value: T), and we will get better error information. ``` vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. java.lang.IllegalArgumentException: vectorSlicer_3b744ea277b2 parameter names given invalid value [f1,f4,f1]. ``` Author: Yanbo Liang Closes #8863 from yanboliang/spark-10750. --- .../src/main/scala/org/apache/spark/ml/param/params.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index de32b7218c277..48f6269e57e98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -65,7 +65,12 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") + val valueToString = value match { + case v: Array[_] => v.mkString("[", ",", "]") + case _ => value.toString + } + throw new IllegalArgumentException( + s"$parent parameter $name given invalid value $valueToString.") } } From 2ea0f2e11b82ef4817c7e6a162ea23da7860b893 Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 22 Sep 2015 11:01:32 -0700 Subject: [PATCH 385/802] [SPARK-9585] Delete the input format caching because some input format are non thread safe If we cache the InputFormat, all tasks on the same executor will share it. Some InputFormat is thread safety, but some are not, such as HiveHBaseTableInputFormat. If tasks share a non thread safe InputFormat, unexpected error may be occurs. To avoid it, I think we should delete the input format caching. Author: xutingjun Author: meiyoula <1039320815@qq.com> Author: Xutingjun Closes #7918 from XuTingjun/cached_inputFormat. --- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8f2655d63b797..77b57132b9f1f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -182,17 +182,11 @@ class HadoopRDD[K, V]( } protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { - if (HadoopRDD.containsCachedMetadata(inputFormatCacheKey)) { - return HadoopRDD.getCachedMetadata(inputFormatCacheKey).asInstanceOf[InputFormat[K, V]] - } - // Once an InputFormat for this RDD is created, cache it so that only one reflection call is - // done in each local process. 
val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] if (newInputFormat.isInstanceOf[Configurable]) { newInputFormat.asInstanceOf[Configurable].setConf(conf) } - HadoopRDD.putCachedMetadata(inputFormatCacheKey, newInputFormat) newInputFormat } From 22d40159e60dd27a428e4051ef607292cbffbff3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 22 Sep 2015 11:07:01 -0700 Subject: [PATCH 386/802] [SPARK-10593] [SQL] fix resolve output of Generate The output of Generate should not be resolved as Reference. Author: Davies Liu Closes #8755 from davies/view. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 ++++++++++++++++ .../spark/sql/catalyst/plans/QueryPlan.scala | 1 - .../catalyst/plans/logical/basicOperators.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 14 ++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02f34cbf58ad0..bf72d47ce1ea6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -378,6 +378,22 @@ class Analyzer( val newOrdering = resolveSortOrders(ordering, child, throws = false) Sort(newOrdering, global, child) + // A special case for Generate, because the output of Generate should not be resolved by + // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate. + case g @ Generate(generator, join, outer, qualifier, output, child) + if child.resolved && !generator.resolved => + val newG = generator transformUp { + case u @ UnresolvedAttribute(nameParts) => + withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) } + case UnresolvedExtractValue(child, fieldExpr) => + ExtractValue(child, fieldExpr, resolver) + } + if (newG.fastEquals(generator)) { + g + } else { + Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) + } + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 55286f9f2fc5c..0ec9f08571082 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 722f69cdca827..ae9482c10f126 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -68,7 +68,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length 
&& - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8126d02335217..bb02473dd17ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1170,4 +1170,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sqlContext.table("`db.t`"), df) } } + + test("SPARK-10593 same column names in lateral view") { + val df = sqlContext.sql( + """ + |select + |insideLayer2.json as a2 + |from (select '{"layer1": {"layer2": "text inside layer 2"}}' json) test + |lateral view json_tuple(json, 'layer1') insideLayer1 as json + |lateral view json_tuple(insideLayer1.json, 'layer2') insideLayer2 as json + """.stripMargin + ) + + checkAnswer(df, Row("text inside layer 2") :: Nil) + } } From 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 22 Sep 2015 11:50:22 -0700 Subject: [PATCH 387/802] [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. This naming is confusing, since the same read path code is used for both sort- and hash-based shuffle. This patch addresses this by renaming HashShuffleReader to BlockStoreShuffleReader. Author: Josh Rosen Closes #8825 from JoshRosen/shuffle-reader-cleanup. --- ...shShuffleReader.scala => BlockStoreShuffleReader.scala} | 5 ++--- .../org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2 +- .../org/apache/spark/shuffle/sort/SortShuffleManager.scala | 3 +-- ...eaderSuite.scala => BlockStoreShuffleReaderSuite.scala} | 7 +++---- 4 files changed, 7 insertions(+), 10 deletions(-) rename core/src/main/scala/org/apache/spark/shuffle/{hash/HashShuffleReader.scala => BlockStoreShuffleReader.scala} (97%) rename core/src/test/scala/org/apache/spark/shuffle/{hash/HashShuffleReaderSuite.scala => BlockStoreShuffleReaderSuite.scala} (96%) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala rename to core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0c8f08f0f3b1b..6dc9a16e58531 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -15,16 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -private[spark] class HashShuffleReader[K, C]( +private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala index 0b46634b8b466..d2e2fc4c110a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager startPartition: Int, endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 476cc1f303da7..9df4e551669cc 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency} import org.apache.spark.shuffle._ -import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext): ShuffleReader[K, C] = { // We currently use the same block store shuffle fetcher as the hash-based shuffle. - new HashShuffleReader( + new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala similarity index 96% rename from core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala rename to core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 05b3afef5b839..a5eafb1b5529e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.shuffle.hash +package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer @@ -28,7 +28,6 @@ import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -56,7 +55,7 @@ class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends Managed } } -class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { +class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { /** * This test makes sure that, when data is read from a HashShuffleReader, the underlying @@ -134,7 +133,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { new BaseShuffleHandle(shuffleId, numMaps, dependency) } - val shuffleReader = new HashShuffleReader( + val shuffleReader = new BlockStoreShuffleReader( shuffleHandle, reduceId, reduceId + 1, From 5017c685f484ec256101d1d33bad11d9e0c0f641 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Sep 2015 12:14:15 -0700 Subject: [PATCH 388/802] [SPARK-10740] [SQL] handle nondeterministic expressions correctly for set operations https://issues.apache.org/jira/browse/SPARK-10740 Author: Wenchen Fan Closes #8858 from cloud-fan/non-deter. --- .../sql/catalyst/optimizer/Optimizer.scala | 69 ++++++++++++++----- .../optimizer/SetOperationPushDownSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 41 +++++++++++ 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 324f40a051c38..63602eaa8ccd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -95,14 +95,14 @@ object SamplePushDown extends Rule[LogicalPlan] { * Intersect: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * because we will not have non-deterministic expressions. + * with deterministic condition. */ -object SetOperationPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -129,34 +129,65 @@ object SetOperationPushDown extends Rule[LogicalPlan] { result.asInstanceOf[A] } + /** + * Splits the condition expression into small conditions by `And`, and partition them by + * deterministic, and finally recombine them by `And`. It returns an expression containing + * all deterministic expressions (the first field of the returned Tuple2) and an expression + * containing all non-deterministic expressions (the second field of the returned Tuple2). 
+ */ + private def partitionByDeterministic(condition: Expression): (Expression, Expression) = { + val andConditions = splitConjunctivePredicates(condition) + andConditions.partition(_.deterministic) match { + case (deterministic, nondeterministic) => + deterministic.reduceOption(And).getOrElse(Literal(true)) -> + nondeterministic.reduceOption(And).getOrElse(Literal(true)) + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Push down filter into union case Filter(condition, u @ Union(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(u) - Union( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) - - // Push down projection through UNION ALL - case Project(projectList, u @ Union(left, right)) => - val rewrites = buildRewrites(u) - Union( - Project(projectList, left), - Project(projectList.map(pushToRight(_, rewrites)), right)) + Filter(nondeterministic, + Union( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) + + // Push down deterministic projection through UNION ALL + case p @ Project(projectList, u @ Union(left, right)) => + if (projectList.forall(_.deterministic)) { + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } else { + p + } // Push down filter through INTERSECT case Filter(condition, i @ Intersect(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(i) - Intersect( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Intersect( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) // Push down filter through EXCEPT case Filter(condition, e @ Except(left, right)) => + val (deterministic, nondeterministic) = partitionByDeterministic(condition) val rewrites = buildRewrites(e) - Except( - Filter(condition, left), - Filter(pushToRight(condition, rewrites), right)) + Filter(nondeterministic, + Except( + Filter(deterministic, left), + Filter(pushToRight(deterministic, rewrites), right) + ) + ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 3fca47a023dc6..1595ad9327423 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -30,7 +30,8 @@ class SetOperationPushDownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - SetOperationPushDown) :: Nil + SetOperationPushDown, + SimplifyFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1370713975f2f..d919877746c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -916,4 +916,45 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(intersect.count() === 30) assert(except.count() === 70) } + + test("SPARK-10740: handle nondeterministic expressions 
correctly for set operations") { + val df1 = (1 to 20).map(Tuple1.apply).toDF("i") + val df2 = (1 to 10).map(Tuple1.apply).toDF("i") + + // When generating expected results at here, we need to follow the implementation of + // Rand expression. + def expected(df: DataFrame): Seq[Row] = { + df.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.filter(_.getInt(0) < rng.nextDouble() * 10) + } + } + + val union = df1.unionAll(df2) + checkAnswer( + union.filter('i < rand(7) * 10), + expected(union) + ) + checkAnswer( + union.select(rand(7)), + union.rdd.collectPartitions().zipWithIndex.flatMap { + case (data, index) => + val rng = new org.apache.spark.util.random.XORShiftRandom(7 + index) + data.map(_ => rng.nextDouble()).map(i => Row(i)) + } + ) + + val intersect = df1.intersect(df2) + checkAnswer( + intersect.filter('i < rand(7) * 10), + expected(intersect) + ) + + val except = df1.except(df2) + checkAnswer( + except.filter('i < rand(7) * 10), + expected(except) + ) + } } From 2204cdb28483b249616068085d4e88554fe6acef Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:29:39 -0700 Subject: [PATCH 389/802] [SPARK-10672] [SQL] Do not fail when we cannot save the metadata of a data source table in a hive compatible way https://issues.apache.org/jira/browse/SPARK-10672 With changes in this PR, we will fallback to same the metadata of a table in Spark SQL specific way if we fail to save it in a hive compatible way (Hive throws an exception because of its internal restrictions, e.g. binary and decimal types cannot be saved to parquet if the metastore is running Hive 0.13). I manually tested the fix with the following test in `DataSourceWithHiveMetastoreCatalogSuite` (`spark.sql.hive.metastore.version=0.13` and `spark.sql.hive.metastore.jars`=`maven`). ``` test(s"fail to save metadata of a parquet table in hive 0.13") { withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath sql( s"""CREATE TABLE t USING $provider |OPTIONS (path '$path') |AS SELECT 1 AS d1, cast("val_1" as binary) AS d2 """.stripMargin) sql( s"""describe formatted t """.stripMargin).collect.foreach(println) sqlContext.table("t").show } } } } ``` Without this fix, we will fail with the following error. 
``` org.apache.hadoop.hive.ql.metadata.HiveException: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:619) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:576) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply$mcV$sp(ClientWrapper.scala:359) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$createTable$1.apply(ClientWrapper.scala:357) at org.apache.spark.sql.hive.client.ClientWrapper$$anonfun$withHiveState$1.apply(ClientWrapper.scala:256) at org.apache.spark.sql.hive.client.ClientWrapper.retryLocked(ClientWrapper.scala:211) at org.apache.spark.sql.hive.client.ClientWrapper.withHiveState(ClientWrapper.scala:248) at org.apache.spark.sql.hive.client.ClientWrapper.createTable(ClientWrapper.scala:357) at org.apache.spark.sql.hive.HiveMetastoreCatalog.createDataSourceTable(HiveMetastoreCatalog.scala:358) at org.apache.spark.sql.hive.execution.CreateMetastoreDataSourceAsSelect.run(commands.scala:285) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult$lzycompute(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.sideEffectResult(commands.scala:57) at org.apache.spark.sql.execution.ExecutedCommand.doExecute(commands.scala:69) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:140) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$5.apply(SparkPlan.scala:138) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:138) at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:58) at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:58) at org.apache.spark.sql.DataFrame.(DataFrame.scala:144) at org.apache.spark.sql.DataFrame.(DataFrame.scala:129) at org.apache.spark.sql.DataFrame$.apply(DataFrame.scala:51) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:725) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.test.SQLTestUtils$$anonfun$sql$1.apply(SQLTestUtils.scala:56) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2$$anonfun$apply$2.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:165) at org.apache.spark.sql.test.SQLTestUtils$class.withTable(SQLTestUtils.scala:150) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTable(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:162) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1$$anonfun$apply$mcV$sp$2.apply(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.test.SQLTestUtils$class.withTempPath(SQLTestUtils.scala:125) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.withTempPath(HiveMetastoreCatalogSuite.scala:52) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply$mcV$sp(HiveMetastoreCatalogSuite.scala:161) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at 
org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite$$anonfun$4$$anonfun$apply$1.apply(HiveMetastoreCatalogSuite.scala:161) at org.scalatest.Transformer$$anonfun$apply$1.apply$mcV$sp(Transformer.scala:22) at org.scalatest.OutcomeOf$class.outcomeOf(OutcomeOf.scala:85) at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104) at org.scalatest.Transformer.apply(Transformer.scala:22) at org.scalatest.Transformer.apply(Transformer.scala:20) at org.scalatest.FunSuiteLike$$anon$1.apply(FunSuiteLike.scala:166) at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:42) at org.scalatest.FunSuiteLike$class.invokeWithFixture$1(FunSuiteLike.scala:163) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.FunSuiteLike$$anonfun$runTest$1.apply(FunSuiteLike.scala:175) at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306) at org.scalatest.FunSuiteLike$class.runTest(FunSuiteLike.scala:175) at org.scalatest.FunSuite.runTest(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.FunSuiteLike$$anonfun$runTests$1.apply(FunSuiteLike.scala:208) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:413) at org.scalatest.SuperEngine$$anonfun$traverseSubNodes$1$1.apply(Engine.scala:401) at scala.collection.immutable.List.foreach(List.scala:318) at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401) at org.scalatest.SuperEngine.org$scalatest$SuperEngine$$runTestsInBranch(Engine.scala:396) at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:483) at org.scalatest.FunSuiteLike$class.runTests(FunSuiteLike.scala:208) at org.scalatest.FunSuite.runTests(FunSuite.scala:1555) at org.scalatest.Suite$class.run(Suite.scala:1424) at org.scalatest.FunSuite.org$scalatest$FunSuiteLike$$super$run(FunSuite.scala:1555) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.FunSuiteLike$$anonfun$run$1.apply(FunSuiteLike.scala:212) at org.scalatest.SuperEngine.runImpl(Engine.scala:545) at org.scalatest.FunSuiteLike$class.run(FunSuiteLike.scala:212) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.org$scalatest$BeforeAndAfterAll$$super$run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.BeforeAndAfterAll$class.liftedTree1$1(BeforeAndAfterAll.scala:257) at org.scalatest.BeforeAndAfterAll$class.run(BeforeAndAfterAll.scala:256) at org.apache.spark.sql.hive.DataSourceWithHiveMetastoreCatalogSuite.run(HiveMetastoreCatalogSuite.scala:52) at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:462) at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:671) at sbt.ForkMain$Run$2.call(ForkMain.java:294) at sbt.ForkMain$Run$2.call(ForkMain.java:284) at java.util.concurrent.FutureTask.run(FutureTask.java:262) at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.UnsupportedOperationException: Unknown field type: binary at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.getObjectInspector(ArrayWritableObjectInspector.java:108) at org.apache.hadoop.hive.ql.io.parquet.serde.ArrayWritableObjectInspector.(ArrayWritableObjectInspector.java:60) at org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe.initialize(ParquetHiveSerDe.java:113) at 
org.apache.hadoop.hive.metastore.MetaStoreUtils.getDeserializer(MetaStoreUtils.java:339) at org.apache.hadoop.hive.ql.metadata.Table.getDeserializerFromMetaStore(Table.java:288) at org.apache.hadoop.hive.ql.metadata.Table.checkValidity(Table.java:194) at org.apache.hadoop.hive.ql.metadata.Hive.createTable(Hive.java:597) ... 76 more ``` Author: Yin Huai Closes #8824 from yhuai/datasourceMetadata. --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 101 +++++++++--------- 1 file changed, 50 insertions(+), 51 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0c1b41e3377e3..012634cb5aeb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -309,69 +309,68 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } // TODO: Support persisting partitioned data source relations in Hive compatible format - val hiveTable = (maybeSerDe, dataSource.relation) match { + val qualifiedTableName = tableIdent.quotedString + val (hiveCompitiableTable, logMessage) = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) - if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - // Hive ParquetSerDe doesn't support decimal type until 1.2.0. - val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) - val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) - - val hiveParquetSupportsDecimal = client.version match { - case org.apache.spark.sql.hive.client.hive.v1_2 => true - case _ => false - } - - if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { - // If Hive version is below 1.2.0, we cannot save Hive compatible schema to - // metastore when the file format is Parquet and the schema has DecimalType. - logWarning { - "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + - "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + - s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." - } - newSparkSQLSpecificMetastoreTable() - } else { - logInfo { - "Persisting data source relation with a single input path into Hive metastore in " + - s"Hive compatible format. Input path: ${relation.paths.head}" - } - newHiveCompatibleMetastoreTable(relation, serde) - } + if relation.paths.length == 1 && relation.partitionColumns.isEmpty => + val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) + val message = + s"Persisting data source relation $qualifiedTableName with a single input path " + + s"into Hive metastore in Hive compatible format. Input path: ${relation.paths.head}." + (Some(hiveTable), message) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => - logWarning { - "Persisting partitioned data source relation into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input path(s): " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting partitioned data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. 
" + + "Input path(s): " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), relation: HadoopFsRelation) => - logWarning { - "Persisting data source relation with multiple input paths into Hive metastore in " + - s"Spark SQL specific format, which is NOT compatible with Hive. Input paths: " + - relation.paths.mkString("\n", "\n", "") - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Persisting data source relation $qualifiedTableName with multiple input paths into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + s"Input paths: " + relation.paths.mkString("\n", "\n", "") + (None, message) case (Some(serde), _) => - logWarning { - s"Data source relation is not a ${classOf[HadoopFsRelation].getSimpleName}. " + - "Persisting it into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + val message = + s"Data source relation $qualifiedTableName is not a " + + s"${classOf[HadoopFsRelation].getSimpleName}. Persisting it into Hive metastore " + + "in Spark SQL specific format, which is NOT compatible with Hive." + (None, message) case _ => - logWarning { + val message = s"Couldn't find corresponding Hive SerDe for data source provider $provider. " + - "Persisting data source relation into Hive metastore in Spark SQL specific format, " + - "which is NOT compatible with Hive." - } - newSparkSQLSpecificMetastoreTable() + s"Persisting data source relation $qualifiedTableName into Hive metastore in " + + s"Spark SQL specific format, which is NOT compatible with Hive." + (None, message) } - client.createTable(hiveTable) + (hiveCompitiableTable, logMessage) match { + case (Some(table), message) => + // We first try to save the metadata of the table in a Hive compatiable way. + // If Hive throws an error, we fall back to save its metadata in the Spark SQL + // specific way. + try { + logInfo(message) + client.createTable(table) + } catch { + case throwable: Throwable => + val warningMessage = + s"Could not persist $qualifiedTableName in a Hive compatible way. Persisting " + + s"it into Hive metastore in Spark SQL specific format." + logWarning(warningMessage, throwable) + val sparkSqlSpecificTable = newSparkSQLSpecificMetastoreTable() + client.createTable(sparkSqlSpecificTable) + } + + case (None, message) => + logWarning(message) + val hiveTable = newSparkSQLSpecificMetastoreTable() + client.createTable(hiveTable) + } } def hiveDefaultTableFilePath(tableName: String): String = { From 5aea987c904b281d7952ad8db40a32561b4ec5cf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 22 Sep 2015 13:31:35 -0700 Subject: [PATCH 390/802] [SPARK-10737] [SQL] When using UnsafeRows, SortMergeJoin may return wrong results https://issues.apache.org/jira/browse/SPARK-10737 Author: Yin Huai Closes #8854 from yhuai/SMJBug. 
--- .../codegen/GenerateProjection.scala | 2 ++ .../apache/spark/sql/execution/Window.scala | 9 ++++-- .../sql/execution/joins/SortMergeJoin.scala | 25 +++++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 28 +++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 2164ddf03d1b2..75524b568d685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -171,6 +171,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { @Override public Object apply(Object r) { + // GenerateProjection does not work with UnsafeRows. + assert(!(r instanceof ${classOf[UnsafeRow].getName})); return new SpecificRow((InternalRow) r); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 0269d6d4b7a1c..f8929530c5036 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -253,7 +253,11 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(partitionSpec, child.output) + val grouping = if (child.outputsUnsafeRows) { + UnsafeProjection.create(partitionSpec, child.output) + } else { + newProjection(partitionSpec, child.output) + } // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow @@ -277,7 +281,8 @@ case class Window( val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. - val currentGroup = nextGroup + // Before we start to fetch new input rows, make a copy of nextGroup. + val currentGroup = nextGroup.copy() rows = new CompactBuffer while (nextRowAvailable && nextGroup == currentGroup) { rows += nextRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 906f20d2a7289..70a1af6a7063a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -56,9 +56,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) - @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) - protected[this] def isUnsafeMode: Boolean = { (codegenEnabled && unsafeEnabled && UnsafeProjection.canSupport(leftKeys) @@ -82,6 +79,28 @@ case class SortMergeJoin( left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => new RowIterator { + // The projection used to extract keys from input rows of the left child. + private[this] val leftKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. 
+ UnsafeProjection.create(leftKeys, left.output) + } else { + newProjection(leftKeys, left.output) + } + } + + // The projection used to extract keys from input rows of the right child. + private[this] val rightKeyGenerator = { + if (isUnsafeMode) { + // It is very important to use UnsafeProjection if input rows are UnsafeRows. + // Otherwise, GenerateProjection will cause wrong results. + UnsafeProjection.create(rightKeys, right.output) + } else { + newProjection(rightKeys, right.output) + } + } + // An ordering that can be used to compare keys from both sides. private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) private[this] var currentLeftRow: InternalRow = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 05b4127cbcaff..eca6f1073889a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1781,4 +1781,32 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Seq(Row(1), Row(1))) } } + + test("SortMergeJoin returns wrong results when using UnsafeRows") { + // This test is for the fix of https://issues.apache.org/jira/browse/SPARK-10737. + // This bug will be triggered when Tungsten is enabled and there are multiple + // SortMergeJoin operators executed in the same task. + val confs = + SQLConf.SORTMERGE_JOIN.key -> "true" :: + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1" :: + SQLConf.TUNGSTEN_ENABLED.key -> "true" :: Nil + withSQLConf(confs: _*) { + val df1 = (1 to 50).map(i => (s"str_$i", i)).toDF("i", "j") + val df2 = + df1 + .join(df1.select(df1("i")), "i") + .select(df1("i"), df1("j")) + + val df3 = df2.withColumnRenamed("i", "i1").withColumnRenamed("j", "j1") + val df4 = + df2 + .join(df3, df2("i") === df3("i1")) + .withColumn("diff", $"j" - $"j1") + .select(df2("i"), df2("j"), $"diff") + + checkAnswer( + df4, + df1.withColumn("diff", lit(0))) + } + } } From a96ba40f7ee1352288ea676d8844e1c8174202eb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Sep 2015 14:11:46 -0700 Subject: [PATCH 391/802] [SPARK-10714] [SPARK-8632] [SPARK-10685] [SQL] Refactor Python UDF handling This patch refactors Python UDF handling: 1. Extract the per-partition Python UDF calling logic from PythonRDD into a PythonRunner. PythonRunner itself expects iterator as input/output, and thus has no dependency on RDD. This way, we can use PythonRunner directly in a mapPartitions call, or in the future in an environment without RDDs. 2. Use PythonRunner in Spark SQL's BatchPythonEvaluation. 3. Updated BatchPythonEvaluation to only use its input once, rather than twice. This should fix Python UDF performance regression in Spark 1.5. There are a number of small cleanups I wanted to do when I looked at the code, but I kept most of those out so the diff looks small. This basically implements the approach in https://github.com/apache/spark/pull/8833, but with some code moving around so the correctness doesn't depend on the inner workings of Spark serialization and task execution. Author: Reynold Xin Closes #8835 from rxin/python-iter-refactor. 
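The single-pass input handling in the new `BatchPythonEvaluation` can be summarised with a small, dependency-free sketch (plain Scala, no Spark or Py4J; `runExternally` is a stand-in for the round trip through the Python worker and is assumed to return results in input order):

```scala
object QueueJoinExample {
  // Buffers every row in a queue at the moment it is sent out, and pairs each
  // result coming back with the row drained from the head of the queue.
  def evaluate[Row, Out](
      input: Iterator[Row],
      extractUdfInput: Row => Int,
      runExternally: Iterator[Int] => Iterator[Out]): Iterator[(Row, Out)] = {
    val queue = new java.util.ArrayDeque[Row]()
    // Side effect: remember each row exactly when it is handed to the worker.
    val toWorker = input.map { row => queue.addLast(row); extractUdfInput(row) }
    // Results come back in order, so the head of the queue is always the row
    // that produced the current result.
    runExternally(toWorker).map(out => (queue.pollFirst(), out))
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator(("a", 1), ("b", 2), ("c", 3))
    // The stand-in "worker" just multiplies the extracted column by 10.
    evaluate[(String, Int), Int](rows, _._2, _.map(_ * 10))
      .foreach(println)  // ((a,1),10), ((b,2),20), ((c,3),30)
  }
}
```

Because the join relies only on ordering, the input iterator is consumed exactly once; the trade-off, called out in the code, is that a very slow Python process lets the queue grow without bound.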
--- .../apache/spark/api/python/PythonRDD.scala | 54 ++++++++++--- .../spark/sql/execution/pythonUDFs.scala | 80 +++++++++++-------- 2 files changed, 89 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 69da180593bb5..3788d1829758a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,6 +24,7 @@ import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JM import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration @@ -38,7 +39,6 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.util.control.NonFatal private[spark] class PythonRDD( parent: RDD[_], @@ -61,11 +61,39 @@ private[spark] class PythonRDD( if (preservePartitoning) firstParent.partitioner else None } + val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val runner = new PythonRunner( + command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars, accumulator, + bufferSize, reuse_worker) + runner.compute(firstParent.iterator(split, context), split.index, context) + } +} + + +/** + * A helper class to run Python UDFs in Spark. + */ +private[spark] class PythonRunner( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + bufferSize: Int, + reuse_worker: Boolean) + extends Logging { + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map( - f => f.getPath()).mkString(",") + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuse_worker) { envVars.put("SPARK_REUSE_WORKER", "1") @@ -75,7 +103,7 @@ private[spark] class PythonRDD( @volatile var released = false // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, split, context) + val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() @@ -183,13 +211,16 @@ private[spark] class PythonRDD( new InterruptibleIterator(context, stdoutIterator) } - val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) - /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext) + class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Exception = null @@ -211,11 +242,11 @@ private[spark] class PythonRDD( val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index - dataOut.writeInt(split.index) + dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size()) for (include <- pythonIncludes.asScala) { @@ -246,7 +277,7 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() @@ -327,7 +358,8 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() - private def getWorkerBroadcasts(worker: Socket) = { + + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index d0411da6fdf5a..c35c726bfc503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} /** * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. @@ -329,7 +329,13 @@ case class EvaluatePython( /** * :: DeveloperApi :: * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. + * + * Python evaluation works by sending the necessary (projected) input data via a socket to an + * external Python process, and combine the result from the Python process with the original row. + * + * For each row we send to Python, we also put it in a queue. For each output row from Python, + * we drain the queue to find the original input row. 
Note that if the Python process is way too + * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ @DeveloperApi case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) @@ -342,51 +348,57 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) + val inputRDD = child.execute().map(_.copy()) + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) - val parent = childResults.mapPartitions { iter => + inputRDD.mapPartitions { iter => EvaluatePython.registerPicklers() // register pickler for Row + + // The queue used to buffer input rows so we can drain it to + // combine input with output from Python. + val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val pickle = new Pickler val currentRow = newMutableProjection(udf.children, child.output)() val fields = udf.children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) - iter.grouped(100).map { inputRows => + + // Input iterator to Python: input rows are grouped so we send them in batches to Python. + // For each row, add it to the queue. + val inputIterator = iter.grouped(100).map { inputRows => val toBePickled = inputRows.map { row => + queue.add(row) EvaluatePython.toJava(currentRow(row), schema) }.toArray pickle.dumps(toBePickled) } - } - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala - } - }.mapPartitions { iter => + val context = TaskContext.get() + + // Output iterator for results from Python. + val outputIterator = new PythonRunner( + udf.command, + udf.envVars, + udf.pythonIncludes, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator, + bufferSize, + reuseWorker + ).compute(inputIterator, context.partitionId(), context) + + val unpickle = new Unpickler val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } + val joined = new JoinedRow - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) + outputIterator.flatMap { pickedResult => + val unpickledBatch = unpickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala + }.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + joined(queue.poll(), row) } } } From 61d4c07f4becb42f054e588be56ed13239644410 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 22 Sep 2015 16:35:43 -0700 Subject: [PATCH 392/802] [SPARK-10640] History server fails to parse TaskCommitDenied ... simply because the code is missing! Author: Andrew Or Closes #8828 from andrewor14/task-end-reason-json. 
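For event logs written before this change the new fields are simply absent, so the parser falls back to placeholder values instead of failing the whole parse. A dependency-free sketch of that rule (the `Map` stands in for the parsed JSON object; all names here are illustrative only):

```scala
object TaskCommitDeniedCompatExample {
  final case class TaskCommitDeniedInfo(jobId: Int, partitionId: Int, attemptNumber: Int)

  // Each field independently falls back to -1 when it is missing from the log.
  def fromParsedJson(fields: Map[String, Int]): TaskCommitDeniedInfo =
    TaskCommitDeniedInfo(
      jobId = fields.getOrElse("Job ID", -1),
      partitionId = fields.getOrElse("Partition ID", -1),
      attemptNumber = fields.getOrElse("Attempt Number", -1))

  def main(args: Array[String]): Unit = {
    // A log written after this change carries all three fields ...
    println(fromParsedJson(Map("Job ID" -> 2, "Partition ID" -> 3, "Attempt Number" -> 4)))
    // ... while an older log yields the placeholders instead of an exception.
    println(fromParsedJson(Map.empty))
  }
}
```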
--- .../scala/org/apache/spark/TaskEndReason.scala | 6 +++++- .../org/apache/spark/util/JsonProtocol.scala | 13 +++++++++++++ .../apache/spark/util/JsonProtocolSuite.scala | 17 +++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7137246bc34f2..9335c5f4160bf 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,13 +17,17 @@ package org.apache.spark -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils +// ============================================================================================== +// NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! +// ============================================================================================== + /** * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 99614a786bd93..40729fa5a4ffe 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -362,6 +362,10 @@ private[spark] object JsonProtocol { ("Stack Trace" -> stackTrace) ~ ("Full Stack Trace" -> exceptionFailure.fullStackTrace) ~ ("Metrics" -> metrics) + case taskCommitDenied: TaskCommitDenied => + ("Job ID" -> taskCommitDenied.jobID) ~ + ("Partition ID" -> taskCommitDenied.partitionID) ~ + ("Attempt Number" -> taskCommitDenied.attemptNumber) case ExecutorLostFailure(executorId, isNormalExit) => ("Executor ID" -> executorId) ~ ("Normal Exit" -> isNormalExit) @@ -770,6 +774,7 @@ private[spark] object JsonProtocol { val exceptionFailure = Utils.getFormattedClassName(ExceptionFailure) val taskResultLost = Utils.getFormattedClassName(TaskResultLost) val taskKilled = Utils.getFormattedClassName(TaskKilled) + val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) @@ -794,6 +799,14 @@ private[spark] object JsonProtocol { ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None) case `taskResultLost` => TaskResultLost case `taskKilled` => TaskKilled + case `taskCommitDenied` => + // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON + // de/serialization logic was not added until 1.5.1. To provide backward compatibility + // for reading those logs, we need to provide default values for all the fields. + val jobId = Utils.jsonOption(json \ "Job ID").map(_.extract[Int]).getOrElse(-1) + val partitionId = Utils.jsonOption(json \ "Partition ID").map(_.extract[Int]).getOrElse(-1) + val attemptNo = Utils.jsonOption(json \ "Attempt Number").map(_.extract[Int]).getOrElse(-1) + TaskCommitDenied(jobId, partitionId, attemptNo) case `executorLostFailure` => val isNormalExit = Utils.jsonOption(json \ "Normal Exit"). 
map(_.extract[Boolean]) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 143c1b901df11..a24bf2931cca0 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -151,6 +151,7 @@ class JsonProtocolSuite extends SparkFunSuite { testTaskEndReason(exceptionFailure) testTaskEndReason(TaskResultLost) testTaskEndReason(TaskKilled) + testTaskEndReason(TaskCommitDenied(2, 3, 4)) testTaskEndReason(ExecutorLostFailure("100", true)) testTaskEndReason(UnknownReason) @@ -352,6 +353,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } + // `TaskCommitDenied` was added in 1.3.0 but JSON de/serialization logic was added in 1.5.1 + test("TaskCommitDenied backward compatibility") { + val denied = TaskCommitDenied(1, 2, 3) + val oldDenied = JsonProtocol.taskEndReasonToJson(denied) + .removeField({ _._1 == "Job ID" }) + .removeField({ _._1 == "Partition ID" }) + .removeField({ _._1 == "Attempt Number" }) + val expectedDenied = TaskCommitDenied(-1, -1, -1) + assertEquals(expectedDenied, JsonProtocol.taskEndReasonFromJson(oldDenied)) + } + /** -------------------------- * | Helper test running methods | * --------------------------- */ @@ -577,6 +589,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => case (TaskKilled, TaskKilled) => + case (TaskCommitDenied(jobId1, partitionId1, attemptNumber1), + TaskCommitDenied(jobId2, partitionId2, attemptNumber2)) => + assert(jobId1 === jobId2) + assert(partitionId1 === partitionId2) + assert(attemptNumber1 === attemptNumber2) case (ExecutorLostFailure(execId1, isNormalExit1), ExecutorLostFailure(execId2, isNormalExit2)) => assert(execId1 === execId2) From 84f81e035e1dab1b42c36563041df6ba16e7b287 Mon Sep 17 00:00:00 2001 From: Zhichao Li Date: Tue, 22 Sep 2015 19:41:57 -0700 Subject: [PATCH 393/802] [SPARK-10310] [SQL] Fixes script transformation field/line delimiters **Please attribute this PR to `Zhichao Li `.** This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding field delimiter SerDe property to the default `LazySimpleSerDe`, and enabling default record reader/writer classes. Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and `TextRecordWriter`, and don't support customizing record reader/writer using `RECORDREADER`/`RECORDWRITER` clauses. This should be addressed in separate PR(s). Author: Cheng Lian Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters. 
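To make the effect concrete, the sketch below shows the kind of query that now works with no explicit row format clause. It assumes a Hive-enabled `sqlContext` and the usual `src(key, value)` test table, so it is an illustration rather than a test from this patch: the script's input columns are separated by the `'\t'` field delimiter added to the default `LazySimpleSerDe`, and its tab-separated stdout is split back into the declared output columns by the default record reader/writer.

```scala
// Rough illustration only; assumes a HiveContext-backed sqlContext and a
// `src(key INT, value STRING)` table, as in the Hive test suites.
val transformed = sqlContext.sql(
  """
    |SELECT TRANSFORM (key, value)
    |USING 'cat'
    |AS (k, v)
    |FROM src
  """.stripMargin)

// With the default tab field delimiter on both sides of the script, 'cat'
// echoes "key<TAB>value" lines and each line is split back into (k, v).
transformed.show()
```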
--- .../org/apache/spark/sql/hive/HiveQl.scala | 52 ++++++++++--- .../hive/execution/ScriptTransformation.scala | 75 +++++++++++++++---- .../resources/data/scripts/test_transform.py | 6 ++ .../sql/hive/execution/SQLQuerySuite.scala | 39 ++++++++++ .../execution/ScriptTransformationSuite.scala | 2 + 5 files changed, 152 insertions(+), 22 deletions(-) create mode 100755 sql/hive/src/test/resources/data/scripts/test_transform.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index d5cd7e98b5267..256440a9a2e97 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node import org.apache.hadoop.hive.ql.parse._ import org.apache.hadoop.hive.ql.plan.PlanUtils import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException @@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C AttributeReference("value", StringType)()), true) } - def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean // Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, None, Nil) + (rowFormat, None, Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C (BaseSemanticAnalyzer.unescapeSQLString(name), BaseSemanticAnalyzer.unescapeSQLString(value)) } - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) - val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause) + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) + + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) val unescapedScript = 
BaseSemanticAnalyzer.unescapeSQLString(script) + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) + } else { + None + } + + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None + } + val schema = HiveScriptIOSchema( inRowFormat, outRowFormat, inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, schemaLess) + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) Some( logical.ScriptTransformation( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 32bddbaeaeaf9..b30117f0de997 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -24,20 +24,22 @@ import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} import org.apache.spark.{Logging, TaskContext} /** @@ -58,6 +60,8 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) @@ -67,6 +71,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream + val localHiveConf = serializedHiveConf.value // In order to avoid deadlocks, we need to consume the error output of the child process. 
// To avoid issues caused by large error output, we use a circular buffer to limit the amount @@ -96,7 +101,8 @@ case class ScriptTransformation( outputStream, proc, stderrBuffer, - TaskContext.get() + TaskContext.get(), + localHiveConf ) // This nullability is a performance optimization in order to avoid an Option.foreach() call @@ -109,6 +115,10 @@ case class ScriptTransformation( val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) + + @Nullable val scriptOutputReader = + ioschema.recordReader(scriptOutputStream, localHiveConf).orNull + var scriptOutputWritable: Writable = null val reusedWritableObject: Writable = if (null != outputSerde) { outputSerde.getSerializedClass().newInstance @@ -134,15 +144,25 @@ case class ScriptTransformation( } } else if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject - try { - scriptOutputWritable.readFields(scriptOutputStream) - true - } catch { - case _: EOFException => - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } + + if (scriptOutputReader != null) { + if (scriptOutputReader.next(scriptOutputWritable) <= 0) { + writerThread.exception.foreach(throw _) false + } else { + true + } + } else { + try { + scriptOutputWritable.readFields(scriptOutputStream) + true + } catch { + case _: EOFException => + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } } } else { true @@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread( outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, - taskContext: TaskContext + taskContext: TaskContext, + conf: Configuration ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { setDaemon(true) @@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread( TaskContext.setTaskContext(taskContext) val dataOutputStream = new DataOutputStream(outputStream) + @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so // let's use a variable to record whether the `finally` block was hit due to an exception @@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread( } else { val writable = inputSerde.serialize( row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + + if (scriptInputWriter != null) { + scriptInputWriter.write(writable) + } else { + prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream) + } } } outputStream.close() @@ -290,6 +317,8 @@ case class HiveScriptIOSchema ( outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { private val defaultFormat = Map( @@ -347,4 +376,24 @@ case class HiveScriptIOSchema ( serde } + + def recordReader( + inputStream: InputStream, + conf: Configuration): Option[RecordReader] = { + recordReaderClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader] + val props = new Properties() + props.putAll(outputSerdeProps.toMap.asJava) + instance.initialize(inputStream, conf, props) + instance + } + } + + def recordWriter(outputStream: OutputStream, conf: 
Configuration): Option[RecordWriter] = { + recordWriterClass.map { klass => + val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter] + instance.initialize(outputStream, conf) + instance + } + } } diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py new file mode 100755 index 0000000000000..ac6d11d8b919c --- /dev/null +++ b/sql/hive/src/test/resources/data/scripts/test_transform.py @@ -0,0 +1,6 @@ +import sys + +delim = sys.argv[1] + +for row in sys.stdin: + print(delim.join([w + '#' for w in row[:-1].split(delim)])) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index bb02473dd17ca..71823e32ad389 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(df, Row("text inside layer 2") :: Nil) } + + test("SPARK-10310: " + + "script transformation using default input/output SerDe and record reader/writer") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + checkAnswer( + sql( + """FROM( + | FROM test SELECT TRANSFORM(a, b) + | USING 'python src/test/resources/data/scripts/test_transform.py "\t"' + | AS (c STRING, d STRING) + |) t + |SELECT c + """.stripMargin), + (0 until 5).map(i => Row(i + "#"))) + } + + test("SPARK-10310: script transformation using LazySimpleSerDe") { + sqlContext + .range(5) + .selectExpr("id AS a", "id AS b") + .registerTempTable("test") + + val df = sql( + """FROM test + |SELECT TRANSFORM(a, b) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + |USING 'python src/test/resources/data/scripts/test_transform.py "|"' + |AS (c STRING, d STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |WITH SERDEPROPERTIES('field.delim' = '|') + """.stripMargin) + + checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#"))) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index cb8d0fca8e693..7cfdb886b585d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { outputSerdeClass = None, inputSerdeProps = Seq.empty, outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, schemaLess = false ) From 558e9c7e60a7c0d85ba26634e97562ad2163e91d Mon Sep 17 00:00:00 2001 From: Matt Hagen Date: Tue, 22 Sep 2015 21:14:25 -0700 Subject: [PATCH 394/802] [SPARK-10663] Removed unnecessary invocation of DataFrame.toDF method. The Scala example under the "Example: Pipeline" heading in this document initializes the "test" variable to a DataFrame. Because test is already a DF, there is not need to call test.toDF as the example does in a subsequent line: model.transform(test.toDF). So, I removed the extraneous toDF invocation. 
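For reference, a minimal sketch of the corrected snippet from `docs/ml-guide.md` (it assumes the surrounding guide example, i.e. a fitted `model: PipelineModel` and the `test` DataFrame built just above via `sqlContext.createDataFrame(...).toDF("id", "text")`; the `println` body is paraphrased from the guide rather than quoted from this patch):

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.Row

// `test` is already a DataFrame, so it can be passed to transform() directly;
// the removed `test.toDF` call was a redundant no-op conversion.
model.transform(test)
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }
```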
Author: Matt Hagen Closes #8875 from hagenhaus/SPARK-10663. --- docs/ml-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0427ac6695aa1..fd3a6167bc65e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -475,7 +475,7 @@ val test = sqlContext.createDataFrame(Seq( )).toDF("id", "text") // Make predictions on test documents. -model.transform(test.toDF) +model.transform(test) .select("id", "text", "probability", "prediction") .collect() .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => From 5548a254755bb84edae2768b94ab1816e1b49b91 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 22 Sep 2015 22:44:09 -0700 Subject: [PATCH 395/802] [SPARK-10652] [SPARK-10742] [STREAMING] Set meaningful job descriptions for all streaming jobs Here is the screenshot after adding the job descriptions to threads that run receivers and the scheduler thread running the batch jobs. ## All jobs page * Added job descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9924165/cda4a372-5cb1-11e5-91ca-d43a32c699e9.png) ## All stages page * Added stage descriptions with links to relevant batch details page ![image](https://cloud.githubusercontent.com/assets/663212/9923814/2cce266a-5cae-11e5-8a3f-dad84d06c50e.png) ## Streaming batch details page * Added the +details link ![image](https://cloud.githubusercontent.com/assets/663212/9921977/24014a32-5c98-11e5-958e-457b6c38065b.png) Author: Tathagata Das Closes #8791 from tdas/SPARK-10652. --- .../scala/org/apache/spark/ui/UIUtils.scala | 62 ++++++++++++++++- .../apache/spark/ui/jobs/AllJobsPage.scala | 14 ++-- .../org/apache/spark/ui/jobs/StageTable.scala | 7 +- .../org/apache/spark/ui/UIUtilsSuite.scala | 66 +++++++++++++++++++ .../spark/streaming/StreamingContext.scala | 4 +- .../streaming/scheduler/JobScheduler.scala | 15 ++++- .../streaming/scheduler/ReceiverTracker.scala | 5 +- .../apache/spark/streaming/ui/BatchPage.scala | 33 ++++++---- .../streaming/StreamingContextSuite.scala | 2 +- 9 files changed, 179 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f2da417724104..21dc8f0b65485 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -18,9 +18,11 @@ package org.apache.spark.ui import java.text.SimpleDateFormat -import java.util.{Locale, Date} +import java.util.{Date, Locale} -import scala.xml.{Node, Text, Unparsed} +import scala.util.control.NonFatal +import scala.xml._ +import scala.xml.transform.{RewriteRule, RuleTransformer} import org.apache.spark.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -395,4 +397,60 @@ private[spark] object UIUtils extends Logging { } + /** + * Returns HTML rendering of a job or stage description. It will try to parse the string as HTML + * and make sure that it only contains anchors with root-relative links. Otherwise, + * the whole string will rendered as a simple escaped text. + * + * Note: In terms of security, only anchor tags with root relative links are supported. So any + * attempts to embed links outside Spark UI, or other tags like