From d341b17c2a0a4fce04045e13fb4a3b0621296320 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 4 Jun 2014 11:27:08 -0700 Subject: [PATCH 01/18] SPARK-1973. Add randomSplit to JavaRDD (with tests, and tidy Java tests) I'd like to use randomSplit through the Java API, and would like to add a convenience wrapper for this method to JavaRDD. This is fairly trivial. (In fact, is the intent that JavaRDD not wrap every RDD method? and that sometimes users should just use JavaRDD.wrapRDD()?) Along the way, I added tests for it, and also touched up the Java API test style and behavior. This is maybe the more useful part of this small change. Author: Sean Owen Author: Xiangrui Meng This patch had conflicts when merged, resolved by Committer: Xiangrui Meng Closes #919 from srowen/SPARK-1973 and squashes the following commits: 148cb7b [Sean Owen] Some final Java test polish, while we are at it 1fc3f3e [Xiangrui Meng] more cleaning on Java 8 tests 9ebc57f [Sean Owen] Use accumulator instead of temp files to test foreach 5efb0be [Sean Owen] Add Java randomSplit, and unit tests (including for sample) 5dcc158 [Sean Owen] Simplified Java 8 test with new language features, and fixed the name of MLB's greatest team 91a1769 [Sean Owen] Touch up minor style issues in existing Java API suite test --- .../org/apache/spark/api/java/JavaRDD.scala | 22 + .../java/org/apache/spark/JavaAPISuite.java | 193 ++++----- .../java/org/apache/spark/Java8APISuite.java | 96 +++-- .../apache/spark/streaming/Java8APISuite.java | 381 +++++++++--------- 4 files changed, 358 insertions(+), 334 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index dc698dea75e43..23d13710794af 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -108,6 +108,28 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) + + /** + * Randomly splits this RDD with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1 + * + * @return split RDDs in an array + */ + def randomSplit(weights: Array[Double]): Array[JavaRDD[T]] = + randomSplit(weights, Utils.random.nextLong) + + /** + * Randomly splits this RDD with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1 + * @param seed random seed + * + * @return split RDDs in an array + */ + def randomSplit(weights: Array[Double], seed: Long): Array[JavaRDD[T]] = + rdd.randomSplit(weights, seed).map(wrapRDD) + /** * Return the union of this RDD and another one. Any identical elements will appear multiple * times (use `.distinct()` to eliminate them). 
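For context, a minimal sketch of how the new Java-side `randomSplit` might be called; the class and variable names below are illustrative only and are not part of the patch. As the scaladoc above notes, the weights are normalized if they do not sum to 1, and passing an explicit seed makes the split reproducible:

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class RandomSplitExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "RandomSplitExample");
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));

    // Split roughly 80/20; weights are normalized if they don't sum to 1,
    // and the explicit seed makes the split deterministic across runs.
    JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] {0.8, 0.2}, 11L);
    JavaRDD<Integer> training = splits[0];
    JavaRDD<Integer> test = splits[1];

    System.out.println(training.count() + " training / " + test.count() + " test");
    sc.stop();
  }
}
```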
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b78309f81cb8c..50a62129116f1 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -23,6 +23,7 @@ import scala.Tuple2; import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.base.Optional; import com.google.common.base.Charsets; @@ -48,7 +49,6 @@ import org.apache.spark.partial.PartialResult; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.StatCounter; -import org.apache.spark.util.Utils; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -70,16 +70,6 @@ public void tearDown() { sc = null; } - static class ReverseIntComparator implements Comparator, Serializable { - - @Override - public int compare(Integer a, Integer b) { - if (a > b) return -1; - else if (a < b) return 1; - else return 0; - } - } - @SuppressWarnings("unchecked") @Test public void sparkContextUnion() { @@ -124,7 +114,7 @@ public void intersection() { JavaRDD intersections = s1.intersection(s2); Assert.assertEquals(3, intersections.count()); - ArrayList list = new ArrayList(); + List list = new ArrayList(); JavaRDD empty = sc.parallelize(list); JavaRDD emptyIntersection = empty.intersection(s2); Assert.assertEquals(0, emptyIntersection.count()); @@ -144,6 +134,28 @@ public void intersection() { Assert.assertEquals(2, pIntersection.count()); } + @Test + public void sample() { + List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + JavaRDD rdd = sc.parallelize(ints); + JavaRDD sample20 = rdd.sample(true, 0.2, 11); + // expected 2 but of course result varies randomly a bit + Assert.assertEquals(3, sample20.count()); + JavaRDD sample20NoReplacement = rdd.sample(false, 0.2, 11); + Assert.assertEquals(2, sample20NoReplacement.count()); + } + + @Test + public void randomSplit() { + List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + JavaRDD rdd = sc.parallelize(ints); + JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11); + Assert.assertEquals(3, splits.length); + Assert.assertEquals(2, splits[0].count()); + Assert.assertEquals(3, splits[1].count()); + Assert.assertEquals(5, splits[2].count()); + } + @Test public void sortByKey() { List> pairs = new ArrayList>(); @@ -161,26 +173,24 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); // Custom comparator - sortedRDD = rdd.sortByKey(new ReverseIntComparator(), false); + sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); Assert.assertEquals(new Tuple2(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); Assert.assertEquals(new Tuple2(0, 4), sortedPairs.get(1)); Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } - static int foreachCalls = 0; - @Test public void foreach() { - foreachCalls = 0; + final Accumulator accum = sc.accumulator(0); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override - public void call(String s) { - foreachCalls++; + public void call(String s) throws IOException { + accum.add(1); } }); - Assert.assertEquals(2, foreachCalls); + Assert.assertEquals(2, accum.value().intValue()); } @Test @@ -188,7 +198,7 @@ public void toLocalIterator() { List correct = 
Arrays.asList(1, 2, 3, 4); JavaRDD rdd = sc.parallelize(correct); List result = Lists.newArrayList(rdd.toLocalIterator()); - Assert.assertTrue(correct.equals(result)); + Assert.assertEquals(correct, result); } @Test @@ -196,7 +206,7 @@ public void zipWithUniqueId() { List dataArray = Arrays.asList(1, 2, 3, 4); JavaPairRDD zip = sc.parallelize(dataArray).zipWithUniqueId(); JavaRDD indexes = zip.values(); - Assert.assertTrue(new HashSet(indexes.collect()).size() == 4); + Assert.assertEquals(4, new HashSet(indexes.collect()).size()); } @Test @@ -205,7 +215,7 @@ public void zipWithIndex() { JavaPairRDD zip = sc.parallelize(dataArray).zipWithIndex(); JavaRDD indexes = zip.values(); List correctIndexes = Arrays.asList(0L, 1L, 2L, 3L); - Assert.assertTrue(indexes.collect().equals(correctIndexes)); + Assert.assertEquals(correctIndexes, indexes.collect()); } @SuppressWarnings("unchecked") @@ -252,8 +262,10 @@ public void cogroup() { new Tuple2("Oranges", 2), new Tuple2("Apples", 3) )); - JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + JavaPairRDD, Iterable>> cogrouped = + categories.cogroup(prices); + Assert.assertEquals("[Fruit, Citrus]", + Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); @@ -281,8 +293,7 @@ public void leftOuterJoin() { rdd1.leftOuterJoin(rdd2).filter( new Function>>, Boolean>() { @Override - public Boolean call(Tuple2>> tup) - throws Exception { + public Boolean call(Tuple2>> tup) { return !tup._2()._2().isPresent(); } }).first(); @@ -356,8 +367,7 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(2, localCounts.get(2).intValue()); Assert.assertEquals(3, localCounts.get(3).intValue()); - localCounts = rdd.reduceByKeyLocally(new Function2() { + localCounts = rdd.reduceByKeyLocally(new Function2() { @Override public Integer call(Integer a, Integer b) { return a + b; @@ -448,16 +458,17 @@ public void map() { JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { @Override public double call(Integer x) { - return 1.0 * x; + return x.doubleValue(); } }).cache(); doubles.collect(); - JavaPairRDD pairs = rdd.mapToPair(new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2(x, x); - } - }).cache(); + JavaPairRDD pairs = rdd.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer x) { + return new Tuple2(x, x); + } + }).cache(); pairs.collect(); JavaRDD strings = rdd.map(new Function() { @Override @@ -487,7 +498,9 @@ public Iterable call(String x) { @Override public Iterable> call(String s) { List> pairs = new LinkedList>(); - for (String word : s.split(" ")) pairs.add(new Tuple2(word, word)); + for (String word : s.split(" ")) { + pairs.add(new Tuple2(word, word)); + } return pairs; } } @@ -499,7 +512,9 @@ public Iterable> call(String s) { @Override public Iterable call(String s) { List lengths = new LinkedList(); - for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); + } return lengths; } }); @@ -521,7 +536,7 @@ public void mapsFromPairsToPairs() { JavaPairRDD swapped = pairRDD.flatMapToPair( new PairFlatMapFunction, String, Integer>() { @Override - public Iterable> call(Tuple2 item) throws Exception { + public Iterable> call(Tuple2 item) { return 
Collections.singletonList(item.swap()); } }); @@ -530,7 +545,7 @@ public Iterable> call(Tuple2 item) thro // There was never a bug here, but it's worth testing: pairRDD.mapToPair(new PairFunction, String, Integer>() { @Override - public Tuple2 call(Tuple2 item) throws Exception { + public Tuple2 call(Tuple2 item) { return item.swap(); } }).collect(); @@ -631,14 +646,10 @@ public void wholeTextFiles() throws IOException { byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8"); String tempDirName = tempDir.getAbsolutePath(); - DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000")); - ds.write(content1); - ds.close(); - ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001")); - ds.write(content2); - ds.close(); - - HashMap container = new HashMap(); + Files.write(content1, new File(tempDirName + "/part-00000")); + Files.write(content2, new File(tempDirName + "/part-00001")); + + Map container = new HashMap(); container.put(tempDirName+"/part-00000", new Text(content1).toString()); container.put(tempDirName+"/part-00001", new Text(content2).toString()); @@ -844,7 +855,7 @@ public void zip() { JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { @Override public double call(Integer x) { - return 1.0 * x; + return x.doubleValue(); } }); JavaPairRDD zipped = rdd.zip(doubles); @@ -859,17 +870,7 @@ public void zipPartitions() { new FlatMapFunction2, Iterator, Integer>() { @Override public Iterable call(Iterator i, Iterator s) { - int sizeI = 0; - int sizeS = 0; - while (i.hasNext()) { - sizeI += 1; - i.next(); - } - while (s.hasNext()) { - sizeS += 1; - s.next(); - } - return Arrays.asList(sizeI, sizeS); + return Arrays.asList(Iterators.size(i), Iterators.size(s)); } }; @@ -883,6 +884,7 @@ public void accumulators() { final Accumulator intAccum = sc.intAccumulator(10); rdd.foreach(new VoidFunction() { + @Override public void call(Integer x) { intAccum.add(x); } @@ -891,6 +893,7 @@ public void call(Integer x) { final Accumulator doubleAccum = sc.doubleAccumulator(10.0); rdd.foreach(new VoidFunction() { + @Override public void call(Integer x) { doubleAccum.add((double) x); } @@ -899,14 +902,17 @@ public void call(Integer x) { // Try a custom accumulator type AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + @Override public Float addInPlace(Float r, Float t) { return r + t; } + @Override public Float addAccumulator(Float r, Float t) { return r + t; } + @Override public Float zero(Float initialValue) { return 0.0f; } @@ -914,6 +920,7 @@ public Float zero(Float initialValue) { final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(new VoidFunction() { + @Override public void call(Integer x) { floatAccum.add((float) x); } @@ -929,7 +936,8 @@ public void call(Integer x) { public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); List> s = rdd.keyBy(new Function() { - public String call(Integer t) throws Exception { + @Override + public String call(Integer t) { return t.toString(); } }).collect(); @@ -941,10 +949,10 @@ public String call(Integer t) throws Exception { public void checkpointAndComputation() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertEquals(false, rdd.isCheckpointed()); + Assert.assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); + 
Assert.assertTrue(rdd.isCheckpointed()); Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect()); } @@ -952,10 +960,10 @@ public void checkpointAndComputation() { public void checkpointAndRestore() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); sc.setCheckpointDir(tempDir.getAbsolutePath()); - Assert.assertEquals(false, rdd.isCheckpointed()); + Assert.assertFalse(rdd.isCheckpointed()); rdd.checkpoint(); rdd.count(); // Forces the DAG to cause a checkpoint - Assert.assertEquals(true, rdd.isCheckpointed()); + Assert.assertTrue(rdd.isCheckpointed()); Assert.assertTrue(rdd.getCheckpointFile().isPresent()); JavaRDD recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); @@ -966,16 +974,17 @@ public void checkpointAndRestore() { @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD rdd2 = rdd1.mapToPair(new PairFunction() { - @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i % 2); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i, i % 2); + } + }); JavaPairRDD rdd3 = rdd2.mapToPair( new PairFunction, Integer, Integer>() { @Override - public Tuple2 call(Tuple2 in) throws Exception { + public Tuple2 call(Tuple2 in) { return new Tuple2(in._2(), in._1()); } }); @@ -992,14 +1001,15 @@ public Tuple2 call(Tuple2 in) throws Excepti public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - JavaPairRDD rdd2 = rdd1.mapToPair(new PairFunction() { - @Override - public Tuple2 call(Integer i) throws Exception { - return new Tuple2(i, i % 2); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i, i % 2); + } + }); - List[] parts = rdd1.collectPartitions(new int[] {0}); + List[] parts = rdd1.collectPartitions(new int[] {0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); parts = rdd1.collectPartitions(new int[] {1, 2}); @@ -1010,14 +1020,14 @@ public Tuple2 call(Integer i) throws Exception { new Tuple2(2, 0)), rdd2.collectPartitions(new int[] {0})[0]); - parts = rdd2.collectPartitions(new int[] {1, 2}); + List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), new Tuple2(4, 0)), - parts[0]); + parts2[0]); Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), new Tuple2(6, 0), new Tuple2(7, 1)), - parts[1]); + parts2[1]); } @Test @@ -1034,10 +1044,12 @@ public void countApproxDistinct() { @Test public void countApproxDistinctByKey() { List> arrayData = new ArrayList>(); - for (int i = 10; i < 100; i++) - for (int j = 0; j < i; j++) + for (int i = 10; i < 100; i++) { + for (int j = 0; j < i; j++) { arrayData.add(new Tuple2(i, j)); - + } + } + double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); List> res = pairRdd.countApproxDistinctByKey(8, 0).collect(); for (Tuple2 resItem : res) { @@ -1053,12 +1065,13 @@ public void countApproxDistinctByKey() { public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 JavaRDD rdd = sc.parallelize(Arrays.asList(1)); - JavaPairRDD pairRDD = rdd.mapToPair(new PairFunction() { - @Override - public Tuple2 call(Integer x) throws Exception { - return new Tuple2(x, new int[] { x }); - } - }); + JavaPairRDD pairRDD = rdd.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer x) { + return new Tuple2(x, new int[] { x }); + } + }); 
pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index c366c10b15a20..729bc0459ce52 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -99,16 +99,16 @@ public void groupBy() { @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 1), - new Tuple2(1, 2), - new Tuple2(2, 1), - new Tuple2(3, 1) + new Tuple2<>(1, 1), + new Tuple2<>(1, 2), + new Tuple2<>(2, 1), + new Tuple2<>(3, 1) )); JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( - new Tuple2(1, 'x'), - new Tuple2(2, 'y'), - new Tuple2(2, 'z'), - new Tuple2(4, 'w') + new Tuple2<>(1, 'x'), + new Tuple2<>(2, 'y'), + new Tuple2<>(2, 'z'), + new Tuple2<>(4, 'w') )); List>>> joined = rdd1.leftOuterJoin(rdd2).collect(); @@ -133,11 +133,11 @@ public void foldReduce() { @Test public void foldByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); @@ -149,11 +149,11 @@ public void foldByKey() { @Test public void reduceByKey() { List> pairs = Arrays.asList( - new Tuple2(2, 1), - new Tuple2(2, 1), - new Tuple2(1, 1), - new Tuple2(3, 2), - new Tuple2(3, 1) + new Tuple2<>(2, 1), + new Tuple2<>(2, 1), + new Tuple2<>(1, 1), + new Tuple2<>(3, 2), + new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); @@ -177,7 +177,7 @@ public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x).cache(); doubles.collect(); - JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2(x, x)) + JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)) .cache(); pairs.collect(); JavaRDD strings = rdd.map(x -> x.toString()).cache(); @@ -194,31 +194,31 @@ public void flatMap() { Assert.assertEquals(11, words.count()); JavaPairRDD pairs = rdd.flatMapToPair(s -> { - List> pairs2 = new LinkedList>(); - for (String word : s.split(" ")) pairs2.add(new Tuple2(word, word)); + List> pairs2 = new LinkedList<>(); + for (String word : s.split(" ")) pairs2.add(new Tuple2<>(word, word)); return pairs2; }); - Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); + Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairs.first()); Assert.assertEquals(11, pairs.count()); JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { - List lengths = new LinkedList(); + List lengths = new LinkedList<>(); for (String word : s.split(" ")) lengths.add(word.length() * 1.0); return lengths; }); Double x = doubles.first(); - Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); + Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } @Test public void mapsFromPairsToPairs() { List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD pairRDD = sc.parallelizePairs(pairs); @@ -251,19 +251,18 @@ public void sequenceFile() { tempDir.deleteOnExit(); 
String outputDir = new File(tempDir, "output").getAbsolutePath(); List> pairs = Arrays.asList( - new Tuple2(1, "a"), - new Tuple2(2, "aa"), - new Tuple2(3, "aaa") + new Tuple2<>(1, "a"), + new Tuple2<>(2, "aa"), + new Tuple2<>(3, "aaa") ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(pair -> - new Tuple2(new IntWritable(pair._1()), new Text(pair._2()))) + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, Text.class) - .mapToPair(pair -> new Tuple2(pair._1().get(), pair._2().toString())); + .mapToPair(pair -> new Tuple2<>(pair._1().get(), pair._2().toString())); Assert.assertEquals(pairs, readRDD.collect()); Utils.deleteRecursively(tempDir); } @@ -325,7 +324,7 @@ public Float zero(Float initialValue) { } }; - final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(x -> floatAccum.add((float) x)); Assert.assertEquals((Float) 25.0f, floatAccum.value()); @@ -338,22 +337,22 @@ public Float zero(Float initialValue) { public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); List> s = rdd.keyBy(x -> x.toString()).collect(); - Assert.assertEquals(new Tuple2("1", 1), s.get(0)); - Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + Assert.assertEquals(new Tuple2<>("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2<>("2", 2), s.get(1)); } @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4)); JavaPairRDD rdd2 = - rdd1.mapToPair(i -> new Tuple2(i, i % 2)); + rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); JavaPairRDD rdd3 = - rdd2.mapToPair(in -> new Tuple2(in._2(), in._1())); + rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); Assert.assertEquals(Arrays.asList( new Tuple2(1, 1), - new Tuple2(0, 2), - new Tuple2(1, 3), - new Tuple2(0, 4)), rdd3.collect()); + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @Test @@ -361,7 +360,7 @@ public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); JavaPairRDD rdd2 = - rdd1.mapToPair(i -> new Tuple2(i, i % 2)); + rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); List[] parts = rdd1.collectPartitions(new int[]{0}); Assert.assertEquals(Arrays.asList(1, 2), parts[0]); @@ -369,16 +368,13 @@ public void collectPartitions() { Assert.assertEquals(Arrays.asList(3, 4), parts[0]); Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); - Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), - new Tuple2(2, 0)), + Assert.assertEquals(Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), rdd2.collectPartitions(new int[]{0})[0]); parts = rdd2.collectPartitions(new int[]{1, 2}); - Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), - new Tuple2(4, 0)), parts[0]); - Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), - new Tuple2(6, 0), - new Tuple2(7, 1)), parts[1]); + Assert.assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts[0]); + Assert.assertEquals(Arrays.asList(new Tuple2<>(5, 1), new Tuple2<>(6, 0), new Tuple2<>(7, 1)), + parts[1]); } @Test @@ -386,7 +382,7 @@ public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); JavaPairRDD pairRDD = - rdd.mapToPair(x -> 
new Tuple2(x, new int[]{x})); + rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException } diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 43df0dea614bc..73091cfe2c09e 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -39,6 +39,7 @@ * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 * lambda syntax. */ +@SuppressWarnings("unchecked") public class Java8APISuite extends LocalJavaStreamingContext implements Serializable { @Test @@ -52,7 +53,7 @@ public void testMap() { Arrays.asList(9, 4)); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream letterCount = stream.map(s -> s.length()); + JavaDStream letterCount = stream.map(String::length); JavaTestUtils.attachTestOutputStream(letterCount); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -63,7 +64,7 @@ public void testMap() { public void testFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("giants"), @@ -81,11 +82,11 @@ public void testFilter() { public void testMapPartitions() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List> expected = Arrays.asList( Arrays.asList("GIANTSDODGERS"), - Arrays.asList("YANKEESRED SOCKS")); + Arrays.asList("YANKEESRED SOX")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream mapped = stream.mapPartitions(in -> { @@ -172,7 +173,7 @@ public void testVariousTransform() { JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); List>> pairInputData = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); @@ -192,32 +193,32 @@ public void testVariousTransform() { public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( Arrays.asList( - new Tuple2("california", "dodgers"), - new Tuple2("new york", "yankees")), + new Tuple2<>("california", "dodgers"), + new Tuple2<>("new york", "yankees")), Arrays.asList( - new Tuple2("california", "sharks"), - new Tuple2("new york", "rangers"))); + new Tuple2<>("california", "sharks"), + new Tuple2<>("new york", "rangers"))); List>> stringStringKVStream2 = Arrays.asList( Arrays.asList( - new Tuple2("california", "giants"), - new Tuple2("new york", "mets")), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "mets")), Arrays.asList( - new Tuple2("california", "ducks"), - new Tuple2("new york", "islanders"))); + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "islanders"))); - List>>> expected = Arrays.asList( + List>>> expected = Arrays.asList( Sets.newHashSet( - new Tuple2>("california", - new Tuple2("dodgers", "giants")), - new Tuple2>("new york", - new Tuple2("yankees", "mets"))), + new Tuple2<>("california", + new Tuple2<>("dodgers", "giants")), + new Tuple2<>("new york", + new Tuple2<>("yankees", 
"mets"))), Sets.newHashSet( - new Tuple2>("california", - new Tuple2("sharks", "ducks")), - new Tuple2>("new york", - new Tuple2("rangers", "islanders")))); + new Tuple2<>("california", + new Tuple2<>("sharks", "ducks")), + new Tuple2<>("new york", + new Tuple2<>("rangers", "islanders")))); JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( ssc, stringStringKVStream1, 1); @@ -232,7 +233,7 @@ public void testTransformWith() { JavaTestUtils.attachTestOutputStream(joined); List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - List>>> unorderedResult = Lists.newArrayList(); + List>>> unorderedResult = Lists.newArrayList(); for (List>> res : result) { unorderedResult.add(Sets.newHashSet(res)); } @@ -251,9 +252,9 @@ public void testVariousTransformWith() { JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); List>> pairInputData1 = - Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + Arrays.asList(Arrays.asList(new Tuple2<>("x", 1))); List>> pairInputData2 = - Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + Arrays.asList(Arrays.asList(new Tuple2<>(1.0, 'x'))); JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( @@ -293,13 +294,13 @@ public void testStreamingContextTransform() { ); List>> pairStream1input = Arrays.asList( - Arrays.asList(new Tuple2(1, "x")), - Arrays.asList(new Tuple2(2, "y")) + Arrays.asList(new Tuple2<>(1, "x")), + Arrays.asList(new Tuple2<>(2, "y")) ); List>>> expected = Arrays.asList( - Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), - Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + Arrays.asList(new Tuple2<>(1, new Tuple2<>(1, "x"))), + Arrays.asList(new Tuple2<>(2, new Tuple2<>(2, "y"))) ); JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); @@ -312,7 +313,7 @@ public void testStreamingContextTransform() { // This is just to test whether this transform to JavaStream compiles JavaDStream transformed1 = ssc.transform( listOfDStreams1, (List> listOfRDDs, Time time) -> { - assert (listOfRDDs.size() == 2); + Assert.assertEquals(2, listOfRDDs.size()); return null; }); @@ -321,13 +322,13 @@ public void testStreamingContextTransform() { JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, (List> listOfRDDs, Time time) -> { - assert (listOfRDDs.size() == 3); + Assert.assertEquals(3, listOfRDDs.size()); JavaRDD rdd1 = (JavaRDD) listOfRDDs.get(0); JavaRDD rdd2 = (JavaRDD) listOfRDDs.get(1); JavaRDD> rdd3 = (JavaRDD>) listOfRDDs.get(2); JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); PairFunction mapToTuple = - (Integer i) -> new Tuple2(i, i); + (Integer i) -> new Tuple2<>(i, i); return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); }); JavaTestUtils.attachTestOutputStream(transformed2); @@ -365,36 +366,36 @@ public void testPairFlatMap() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(6, "g"), - new Tuple2(6, "i"), - new Tuple2(6, "a"), - new Tuple2(6, "n"), - new Tuple2(6, "t"), - new Tuple2(6, "s")), + new Tuple2<>(6, "g"), + new Tuple2<>(6, "i"), + new Tuple2<>(6, "a"), + new Tuple2<>(6, "n"), + new Tuple2<>(6, "t"), + new Tuple2<>(6, "s")), Arrays.asList( - new Tuple2(7, "d"), - new Tuple2(7, "o"), - new Tuple2(7, "d"), - new Tuple2(7, "g"), - new Tuple2(7, "e"), - new Tuple2(7, "r"), - new Tuple2(7, "s")), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "o"), + new Tuple2<>(7, "d"), + new Tuple2<>(7, "g"), + new Tuple2<>(7, "e"), + 
new Tuple2<>(7, "r"), + new Tuple2<>(7, "s")), Arrays.asList( - new Tuple2(9, "a"), - new Tuple2(9, "t"), - new Tuple2(9, "h"), - new Tuple2(9, "l"), - new Tuple2(9, "e"), - new Tuple2(9, "t"), - new Tuple2(9, "i"), - new Tuple2(9, "c"), - new Tuple2(9, "s"))); + new Tuple2<>(9, "a"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "h"), + new Tuple2<>(9, "l"), + new Tuple2<>(9, "e"), + new Tuple2<>(9, "t"), + new Tuple2<>(9, "i"), + new Tuple2<>(9, "c"), + new Tuple2<>(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream flatMapped = stream.flatMapToPair(s -> { List> out = Lists.newArrayList(); for (String letter : s.split("(?!^)")) { - out.add(new Tuple2(s.length(), letter)); + out.add(new Tuple2<>(s.length(), letter)); } return out; }); @@ -411,12 +412,8 @@ public void testPairFlatMap() { */ public static > void assertOrderInvariantEquals( List> expected, List> actual) { - for (List list : expected) { - Collections.sort(list); - } - for (List list : actual) { - Collections.sort(list); - } + expected.forEach((List list) -> Collections.sort(list)); + actual.forEach((List list) -> Collections.sort(list)); Assert.assertEquals(expected, actual); } @@ -424,11 +421,11 @@ public static > void assertOrderInvariantEquals( public void testPairFilter() { List> inputData = Arrays.asList( Arrays.asList("giants", "dodgers"), - Arrays.asList("yankees", "red socks")); + Arrays.asList("yankees", "red sox")); List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("giants", 6)), - Arrays.asList(new Tuple2("yankees", 7))); + Arrays.asList(new Tuple2<>("giants", 6)), + Arrays.asList(new Tuple2<>("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = @@ -441,26 +438,26 @@ public void testPairFilter() { } List>> stringStringKVStream = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers"), - new Tuple2("california", "giants"), - new Tuple2("new york", "yankees"), - new Tuple2("new york", "mets")), - Arrays.asList(new Tuple2("california", "sharks"), - new Tuple2("california", "ducks"), - new Tuple2("new york", "rangers"), - new Tuple2("new york", "islanders"))); + Arrays.asList(new Tuple2<>("california", "dodgers"), + new Tuple2<>("california", "giants"), + new Tuple2<>("new york", "yankees"), + new Tuple2<>("new york", "mets")), + Arrays.asList(new Tuple2<>("california", "sharks"), + new Tuple2<>("california", "ducks"), + new Tuple2<>("new york", "rangers"), + new Tuple2<>("new york", "islanders"))); List>> stringIntKVStream = Arrays.asList( Arrays.asList( - new Tuple2("california", 1), - new Tuple2("california", 3), - new Tuple2("new york", 4), - new Tuple2("new york", 1)), + new Tuple2<>("california", 1), + new Tuple2<>("california", 3), + new Tuple2<>("new york", 4), + new Tuple2<>("new york", 1)), Arrays.asList( - new Tuple2("california", 5), - new Tuple2("california", 5), - new Tuple2("new york", 3), - new Tuple2("new york", 1))); + new Tuple2<>("california", 5), + new Tuple2<>("california", 5), + new Tuple2<>("new york", 3), + new Tuple2<>("new york", 1))); @Test public void testPairMap() { // Maps pair -> pair of different type @@ -468,15 +465,15 @@ public void testPairMap() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), 
+ new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -494,21 +491,21 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "california"), - new Tuple2(3, "california"), - new Tuple2(4, "new york"), - new Tuple2(1, "new york")), + new Tuple2<>(1, "california"), + new Tuple2<>(3, "california"), + new Tuple2<>(4, "new york"), + new Tuple2<>(1, "new york")), Arrays.asList( - new Tuple2(5, "california"), - new Tuple2(5, "california"), - new Tuple2(3, "new york"), - new Tuple2(1, "new york"))); + new Tuple2<>(5, "california"), + new Tuple2<>(5, "california"), + new Tuple2<>(3, "new york"), + new Tuple2<>(1, "new york"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reversed = pairStream.mapPartitionsToPair(in -> { - LinkedList> out = new LinkedList>(); + LinkedList> out = new LinkedList<>(); while (in.hasNext()) { Tuple2 next = in.next(); out.add(next.swap()); @@ -530,7 +527,8 @@ public void testPairMap2() { // Maps pair -> single Arrays.asList(1, 3, 4, 1), Arrays.asList(5, 5, 3, 1)); - JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map(in -> in._2()); JavaTestUtils.attachTestOutputStream(reversed); @@ -543,31 +541,31 @@ public void testPairMap2() { // Maps pair -> single public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2)), + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2)), Arrays.asList( - new Tuple2("hi", 1), - new Tuple2("ho", 2))); + new Tuple2<>("hi", 1), + new Tuple2<>("ho", 2))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o")), + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o")), Arrays.asList( - new Tuple2(1, "h"), - new Tuple2(1, "i"), - new Tuple2(2, "h"), - new Tuple2(2, "o"))); + new Tuple2<>(1, "h"), + new Tuple2<>(1, "i"), + new Tuple2<>(2, "h"), + new Tuple2<>(2, "o"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream flatMapped = pairStream.flatMapToPair(in -> { - List> out = new LinkedList>(); + List> out = new LinkedList<>(); for (Character s : in._1().toCharArray()) { - out.add(new Tuple2(in._2(), s.toString())); + out.add(new Tuple2<>(in._2(), s.toString())); } return out; }); @@ -584,11 +582,11 @@ public void testPairReduceByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = 
JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -608,11 +606,11 @@ public void testCombineByKey() { List>> expected = Arrays.asList( Arrays.asList( - new Tuple2("california", 4), - new Tuple2("new york", 5)), + new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), Arrays.asList( - new Tuple2("california", 10), - new Tuple2("new york", 4))); + new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); @@ -632,12 +630,12 @@ public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -656,12 +654,12 @@ public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -689,12 +687,12 @@ public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", 4), - new Tuple2("new york", 5)), - Arrays.asList(new Tuple2("california", 14), - new Tuple2("new york", 9)), - Arrays.asList(new Tuple2("california", 10), - new Tuple2("new york", 4))); + Arrays.asList(new Tuple2<>("california", 4), + new Tuple2<>("new york", 5)), + Arrays.asList(new Tuple2<>("california", 14), + new Tuple2<>("new york", 9)), + Arrays.asList(new Tuple2<>("california", 10), + new Tuple2<>("new york", 4))); JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); @@ -713,27 +711,27 @@ public void testReduceByKeyAndWindowWithInverse() { public void testPairTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List>> expected = Arrays.asList( Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5)), + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5)), Arrays.asList( - new Tuple2(1, 5), - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5))); + new Tuple2<>(1, 5), + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); 
@@ -751,15 +749,15 @@ public void testPairTransform() { public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( Arrays.asList( - new Tuple2(3, 5), - new Tuple2(1, 5), - new Tuple2(4, 5), - new Tuple2(2, 5)), + new Tuple2<>(3, 5), + new Tuple2<>(1, 5), + new Tuple2<>(4, 5), + new Tuple2<>(2, 5)), Arrays.asList( - new Tuple2(2, 5), - new Tuple2(3, 5), - new Tuple2(4, 5), - new Tuple2(1, 5))); + new Tuple2<>(2, 5), + new Tuple2<>(3, 5), + new Tuple2<>(4, 5), + new Tuple2<>(1, 5))); List> expected = Arrays.asList( Arrays.asList(3, 1, 4, 2), @@ -780,20 +778,20 @@ public void testMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "DODGERS"), - new Tuple2("california", "GIANTS"), - new Tuple2("new york", "YANKEES"), - new Tuple2("new york", "METS")), - Arrays.asList(new Tuple2("california", "SHARKS"), - new Tuple2("california", "DUCKS"), - new Tuple2("new york", "RANGERS"), - new Tuple2("new york", "ISLANDERS"))); + Arrays.asList(new Tuple2<>("california", "DODGERS"), + new Tuple2<>("california", "GIANTS"), + new Tuple2<>("new york", "YANKEES"), + new Tuple2<>("new york", "METS")), + Arrays.asList(new Tuple2<>("california", "SHARKS"), + new Tuple2<>("california", "DUCKS"), + new Tuple2<>("new york", "RANGERS"), + new Tuple2<>("new york", "ISLANDERS"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream mapped = pairStream.mapValues(s -> s.toUpperCase()); + JavaPairDStream mapped = pairStream.mapValues(String::toUpperCase); JavaTestUtils.attachTestOutputStream(mapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -805,34 +803,29 @@ public void testFlatMapValues() { List>> inputData = stringStringKVStream; List>> expected = Arrays.asList( - Arrays.asList(new Tuple2("california", "dodgers1"), - new Tuple2("california", "dodgers2"), - new Tuple2("california", "giants1"), - new Tuple2("california", "giants2"), - new Tuple2("new york", "yankees1"), - new Tuple2("new york", "yankees2"), - new Tuple2("new york", "mets1"), - new Tuple2("new york", "mets2")), - Arrays.asList(new Tuple2("california", "sharks1"), - new Tuple2("california", "sharks2"), - new Tuple2("california", "ducks1"), - new Tuple2("california", "ducks2"), - new Tuple2("new york", "rangers1"), - new Tuple2("new york", "rangers2"), - new Tuple2("new york", "islanders1"), - new Tuple2("new york", "islanders2"))); + Arrays.asList(new Tuple2<>("california", "dodgers1"), + new Tuple2<>("california", "dodgers2"), + new Tuple2<>("california", "giants1"), + new Tuple2<>("california", "giants2"), + new Tuple2<>("new york", "yankees1"), + new Tuple2<>("new york", "yankees2"), + new Tuple2<>("new york", "mets1"), + new Tuple2<>("new york", "mets2")), + Arrays.asList(new Tuple2<>("california", "sharks1"), + new Tuple2<>("california", "sharks2"), + new Tuple2<>("california", "ducks1"), + new Tuple2<>("california", "ducks2"), + new Tuple2<>("new york", "rangers1"), + new Tuple2<>("new york", "rangers2"), + new Tuple2<>("new york", "islanders1"), + new Tuple2<>("new york", "islanders2"))); JavaDStream> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - - JavaPairDStream flatMapped = pairStream.flatMapValues(in -> { - List out = new ArrayList(); - out.add(in + "1"); - out.add(in + "2"); - return out; - }); + JavaPairDStream flatMapped = + 
pairStream.flatMapValues(in -> Arrays.asList(in + "1", in + "2")); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); From 189df165bb7cb8bc8ede48d0e7f8d8b5cd31d299 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 4 Jun 2014 12:56:56 -0700 Subject: [PATCH 02/18] [SPARK-1752][MLLIB] Standardize text format for vectors and labeled points We should standardize the text format used to represent vectors and labeled points. The proposed formats are the following: 1. dense vector: `[v0,v1,..]` 2. sparse vector: `(size,[i0,i1],[v0,v1])` 3. labeled point: `(label,vector)` where "(..)" indicates a tuple and "[...]" indicate an array. `loadLabeledPoints` is added to pyspark's `MLUtils`. I didn't add `loadVectors` to pyspark because `RDD.saveAsTextFile` cannot stringify dense vectors in the proposed format automatically. `MLUtils#saveLabeledData` and `MLUtils#loadLabeledData` are deprecated. Users should use `RDD#saveAsTextFile` and `MLUtils#loadLabeledPoints` instead. In Scala, `MLUtils#loadLabeledPoints` is compatible with the format used by `MLUtils#loadLabeledData`. CC: @mateiz, @srowen Author: Xiangrui Meng Closes #685 from mengxr/labeled-io and squashes the following commits: 2d1116a [Xiangrui Meng] make loadLabeledData/saveLabeledData deprecated since 1.0.1 297be75 [Xiangrui Meng] change LabeledPoint.parse to LabeledPointParser.parse to maintain binary compatibility d6b1473 [Xiangrui Meng] Merge branch 'master' into labeled-io 56746ea [Xiangrui Meng] replace # by . 623a5f0 [Xiangrui Meng] merge master f06d5ba [Xiangrui Meng] add docs and minor updates 640fe0c [Xiangrui Meng] throw SparkException 5bcfbc4 [Xiangrui Meng] update test to add scientific notations e86bf38 [Xiangrui Meng] remove NumericTokenizer 050fca4 [Xiangrui Meng] use StringTokenizer 6155b75 [Xiangrui Meng] merge master f644438 [Xiangrui Meng] remove parse methods based on eval from pyspark a41675a [Xiangrui Meng] python loadLabeledPoint uses Scala's implementation ce9a475 [Xiangrui Meng] add deserialize_labeled_point to pyspark with tests e9fcd49 [Xiangrui Meng] add serializeLabeledPoint and tests aea4ae3 [Xiangrui Meng] minor updates 810d6df [Xiangrui Meng] update tokenizer/parser implementation 7aac03a [Xiangrui Meng] remove Scala parsers c1885c1 [Xiangrui Meng] add headers and minor changes b0c50cb [Xiangrui Meng] add customized parser d731817 [Xiangrui Meng] style update 63dc396 [Xiangrui Meng] add loadLabeledPoints to pyspark ea122b5 [Xiangrui Meng] Merge branch 'master' into labeled-io cd6c78f [Xiangrui Meng] add __str__ and parse to LabeledPoint a7a178e [Xiangrui Meng] add stringify to pyspark's Vectors 5c2dbfa [Xiangrui Meng] add parse to pyspark's Vectors 7853f88 [Xiangrui Meng] update pyspark's SparseVector.__str__ e761d32 [Xiangrui Meng] make LabelPoint.parse compatible with the dense format used before v1.0 and deprecate loadLabeledData and saveLabeledData 9e63a02 [Xiangrui Meng] add loadVectors and loadLabeledPoints 19aa523 [Xiangrui Meng] update toString and add parsers for Vectors and LabeledPoint --- .../examples/mllib/DecisionTreeRunner.scala | 2 +- .../mllib/api/python/PythonMLLibAPI.scala | 33 ++++- .../apache/spark/mllib/linalg/Vectors.scala | 33 ++++- .../spark/mllib/regression/LabeledPoint.scala | 31 ++++- .../mllib/util/LinearDataGenerator.scala | 3 +- .../LogisticRegressionDataGenerator.scala | 3 +- .../org/apache/spark/mllib/util/MLUtils.scala | 47 ++++++- .../spark/mllib/util/NumericParser.scala | 121 
++++++++++++++++++ .../spark/mllib/util/SVMDataGenerator.scala | 2 +- .../api/python/PythonMLLibAPISuite.scala | 60 +++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 25 ++++ .../mllib/regression/LabeledPointSuite.scala | 39 ++++++ .../spark/mllib/util/MLUtilsSuite.scala | 30 ++++- .../spark/mllib/util/NumericParserSuite.scala | 42 ++++++ python/pyspark/mllib/_common.py | 72 ++++++++--- python/pyspark/mllib/linalg.py | 34 +++-- python/pyspark/mllib/regression.py | 5 +- python/pyspark/mllib/util.py | 69 +++++++--- 18 files changed, 579 insertions(+), 72 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 9832bec90d7ee..b3cc361154198 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -99,7 +99,7 @@ object DecisionTreeRunner { val sc = new SparkContext(conf) // Load training data and cache it. - val examples = MLUtils.loadLabeledData(sc, params.input).cache() + val examples = MLUtils.loadLabeledPoints(sc, params.input).cache() val splits = examples.randomSplit(Array(0.8, 0.2)) val training = splits(0).cache() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 7c65b0d4750fa..c44173793b39a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -20,12 +20,13 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD /** @@ -41,7 +42,7 @@ class PythonMLLibAPI extends Serializable { private val DENSE_MATRIX_MAGIC: Byte = 3 private val LABELED_POINT_MAGIC: Byte = 4 - private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { require(bytes.length - offset >= 5, "Byte array too short") val magic = bytes(offset) if (magic == DENSE_VECTOR_MAGIC) { @@ -116,7 +117,7 @@ class PythonMLLibAPI extends Serializable { bytes } - private def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { case s: SparseVector => serializeSparseVector(s) case _ => @@ -167,7 +168,18 @@ class PythonMLLibAPI extends Serializable { bytes } - private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + private[python] 
def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { + val fb = serializeDoubleVector(p.features) + val bytes = new Array[Byte](1 + 8 + fb.length) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(LABELED_POINT_MAGIC) + bb.putDouble(p.label) + bb.put(fb) + bytes + } + + private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { require(bytes.length >= 9, "Byte array too short") val magic = bytes(0) if (magic != LABELED_POINT_MAGIC) { @@ -179,6 +191,19 @@ class PythonMLLibAPI extends Serializable { LabeledPoint(label, deserializeDoubleVector(bytes, 9)) } + /** + * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @param minPartitions min number of partitions + * @return serialized labeled points stored in a JavaRDD of byte array + */ + def loadLabeledPoints( + jsc: JavaSparkContext, + path: String, + minPartitions: Int): JavaRDD[Array[Byte]] = + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint).toJavaRDD() + private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 84d223908c1f6..c818a0b9c3e43 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,13 +17,16 @@ package org.apache.spark.mllib.linalg -import java.lang.{Iterable => JavaIterable, Integer => JavaInteger, Double => JavaDouble} +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util.Arrays import scala.annotation.varargs import scala.collection.JavaConverters._ -import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} + +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. @@ -124,6 +127,25 @@ object Vectors { }.toSeq) } + /** + * Parses a string resulted from `Vector#toString` into + * an [[org.apache.spark.mllib.linalg.Vector]]. + */ + def parse(s: String): Vector = { + parseNumeric(NumericParser.parse(s)) + } + + private[mllib] def parseNumeric(any: Any): Vector = { + any match { + case values: Array[Double] => + Vectors.dense(values) + case Seq(size: Double, indices: Array[Double], values: Array[Double]) => + Vectors.sparse(size.toInt, indices.map(_.toInt), values) + case other => + throw new SparkException(s"Cannot parse $other.") + } + } + /** * Creates a vector instance from a breeze vector. 
*/ @@ -175,9 +197,10 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - override def toString: String = { - "(" + size + "," + indices.zip(values).mkString("[", "," ,"]") + ")" - } + require(indices.length == values.length) + + override def toString: String = + "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) override def toArray: Array[Double] = { val data = new Array[Double](size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 3deab1ab785b9..62a03af4a9964 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,7 +17,9 @@ package org.apache.spark.mllib.regression -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.SparkException /** * Class that represents the features and labels of a data point. @@ -27,6 +29,31 @@ import org.apache.spark.mllib.linalg.Vector */ case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "LabeledPoint(%s, %s)".format(label, features) + "(%s,%s)".format(label, features) + } +} + +/** + * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. + */ +private[mllib] object LabeledPointParser { + /** + * Parses a string resulted from `LabeledPoint#toString` into + * an [[org.apache.spark.mllib.regression.LabeledPoint]]. + */ + def parse(s: String): LabeledPoint = { + if (s.startsWith("(")) { + NumericParser.parse(s) match { + case Seq(label: Double, numeric: Any) => + LabeledPoint(label, Vectors.parseNumeric(numeric)) + case other => + throw new SparkException(s"Cannot parse $other.") + } + } else { // dense format used before v1.0 + val parts = s.split(',') + val label = java.lang.Double.parseDouble(parts(0)) + val features = Vectors.dense(parts(1).trim().split(' ').map(java.lang.Double.parseDouble)) + LabeledPoint(label, features) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c8e160d00c2d6..69299c219878c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -129,7 +129,8 @@ object LinearDataGenerator { val sc = new SparkContext(sparkMaster, "LinearDataGenerator") val data = generateLinearRDD(sc, nexamples, nfeatures, eps, nparts = parts) - MLUtils.saveLabeledData(data, outputPath) + data.saveAsTextFile(outputPath) + sc.stop() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index c82cd8fd4641c..9d802678c4a77 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -79,7 +79,8 @@ object LogisticRegressionDataGenerator { val sc = new SparkContext(sparkMaster, "LogisticRegressionDataGenerator") val data = generateLogisticRDD(sc, nexamples, nfeatures, eps, parts) - MLUtils.saveLabeledData(data, outputPath) + data.saveAsTextFile(outputPath) + sc.stop() } } 
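To make the standardized text format from the commit message concrete (dense vector `[v0,v1,..]`, sparse vector `(size,[i0,i1],[v0,v1])`, labeled point `(label,vector)`), here is a small sketch using the public `Vectors` helpers touched by this patch; the expected-output comments assume the `toString` implementations shown in the diff:

```java
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class VectorTextFormatExample {
  public static void main(String[] args) {
    Vector dense = Vectors.dense(1.0, 0.0, 3.0);
    Vector sparse = Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0});

    // These strings are what RDD#saveAsTextFile writes for each vector.
    System.out.println(dense);   // [1.0,0.0,3.0]
    System.out.println(sparse);  // (3,[0,2],[1.0,3.0])

    // Vectors.parse reads the text form back; MLUtils.loadLabeledPoints does the
    // same for "(label,vector)" lines (and still accepts the pre-1.0 dense format).
    Vector parsed = Vectors.parse("(3,[0,2],[1.0,3.0])");
    System.out.println(parsed);  // (3,[0,2],[1.0,3.0])
  }
}
```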
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index e598b6cb171a8..aaf92a1a8869a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliSampler -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel @@ -180,7 +180,39 @@ object MLUtils { } /** - * :: Experimental :: + * Loads vectors saved using `RDD[Vector].saveAsTextFile`. + * @param sc Spark context + * @param path file or directory path in any Hadoop-supported file system URI + * @param minPartitions min number of partitions + * @return vectors stored as an RDD[Vector] + */ + def loadVectors(sc: SparkContext, path: String, minPartitions: Int): RDD[Vector] = + sc.textFile(path, minPartitions).map(Vectors.parse) + + /** + * Loads vectors saved using `RDD[Vector].saveAsTextFile` with the default number of partitions. + */ + def loadVectors(sc: SparkContext, path: String): RDD[Vector] = + sc.textFile(path, sc.defaultMinPartitions).map(Vectors.parse) + + /** + * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile`. + * @param sc Spark context + * @param path file or directory path in any Hadoop-supported file system URI + * @param minPartitions min number of partitions + * @return labeled points stored as an RDD[LabeledPoint] + */ + def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = + sc.textFile(path, minPartitions).map(LabeledPointParser.parse) + + /** + * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of + * partitions. + */ + def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = + loadLabeledPoints(sc, dir, sc.defaultMinPartitions) + + /** * Load labeled data from a file. The data format used here is * , ... * where , are feature values in Double and is the corresponding label as Double. @@ -189,8 +221,11 @@ object MLUtils { * @param dir Directory to the input data files. * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is * the label, and the second element represents the feature values (an array of Double). + * + * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and + * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. */ - @Experimental + @deprecated("Should use MLUtils.loadLabeledPoints instead.", "1.0.1") def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => val parts = line.split(',') @@ -201,15 +236,17 @@ object MLUtils { } /** - * :: Experimental :: * Save labeled data to a file. The data format used here is * , ... * where , are feature values in Double and is the corresponding label as Double. * * @param data An RDD of LabeledPoints containing data to be saved. * @param dir Directory to save the data. + * + * @deprecated Should use [[org.apache.spark.rdd.RDD#saveAsTextFile]] for saving and + * [[org.apache.spark.mllib.util.MLUtils#loadLabeledPoints]] for loading. 
*/ - @Experimental + @deprecated("Should use RDD[LabeledPoint].saveAsTextFile instead.", "1.0.1") def saveLabeledData(data: RDD[LabeledPoint], dir: String) { val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) dataStr.saveAsTextFile(dir) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala new file mode 100644 index 0000000000000..f7cba6c6cb628 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import java.util.StringTokenizer + +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +import org.apache.spark.SparkException + +/** + * Simple parser for a numeric structure consisting of three types: + * + * - number: a double in Java's floating number format + * - array: an array of numbers stored as `[v0,v1,...,vn]` + * - tuple: a list of numbers, arrays, or tuples stored as `(...)` + */ +private[mllib] object NumericParser { + + /** Parses a string into a Double, an Array[Double], or a Seq[Any]. 
*/ + def parse(s: String): Any = { + val tokenizer = new StringTokenizer(s, "()[],", true) + if (tokenizer.hasMoreTokens()) { + val token = tokenizer.nextToken() + if (token == "(") { + parseTuple(tokenizer) + } else if (token == "[") { + parseArray(tokenizer) + } else { + // expecting a number + parseDouble(token) + } + } else { + throw new SparkException(s"Cannot find any token from the input string.") + } + } + + private def parseArray(tokenizer: StringTokenizer): Array[Double] = { + val values = ArrayBuffer.empty[Double] + var parsing = true + var allowComma = false + var token: String = null + while (parsing && tokenizer.hasMoreTokens()) { + token = tokenizer.nextToken() + if (token == "]") { + parsing = false + } else if (token == ",") { + if (allowComma) { + allowComma = false + } else { + throw new SparkException("Found a ',' at a wrong position.") + } + } else { + // expecting a number + values.append(parseDouble(token)) + allowComma = true + } + } + if (parsing) { + throw new SparkException(s"An array must end with ']'.") + } + values.toArray + } + + private def parseTuple(tokenizer: StringTokenizer): Seq[_] = { + val items = ListBuffer.empty[Any] + var parsing = true + var allowComma = false + var token: String = null + while (parsing && tokenizer.hasMoreTokens()) { + token = tokenizer.nextToken() + if (token == "(") { + items.append(parseTuple(tokenizer)) + allowComma = true + } else if (token == "[") { + items.append(parseArray(tokenizer)) + allowComma = true + } else if (token == ",") { + if (allowComma) { + allowComma = false + } else { + throw new SparkException("Found a ',' at a wrong position.") + } + } else if (token == ")") { + parsing = false + } else { + // expecting a number + items.append(parseDouble(token)) + allowComma = true + } + } + if (parsing) { + throw new SparkException(s"A tuple must end with ')'.") + } + items + } + + private def parseDouble(s: String): Double = { + try { + java.lang.Double.parseDouble(s) + } catch { + case e: Throwable => + throw new SparkException(s"Cannot parse a double from: $s", e) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index ba8190b0e07e8..7db97e6bac688 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -65,7 +65,7 @@ object SVMDataGenerator { LabeledPoint(y, Vectors.dense(x)) } - MLUtils.saveLabeledData(data, outputPath) + data.saveAsTextFile(outputPath) sc.stop() } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala new file mode 100644 index 0000000000000..642843f90204c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint + +class PythonMLLibAPISuite extends FunSuite { + val py = new PythonMLLibAPI + + test("vector serialization") { + val vectors = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0), + Vectors.dense(0.0, -2.0), + Vectors.sparse(0, Array.empty[Int], Array.empty[Double]), + Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), + Vectors.sparse(2, Array(1), Array(-2.0))) + vectors.foreach { v => + val bytes = py.serializeDoubleVector(v) + val u = py.deserializeDoubleVector(bytes) + assert(u.getClass === v.getClass) + assert(u === v) + } + } + + test("labeled point serialization") { + val points = Seq( + LabeledPoint(0.0, Vectors.dense(Array.empty[Double])), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(-0.5, Vectors.dense(0.0, -2.0)), + LabeledPoint(0.0, Vectors.sparse(0, Array.empty[Int], Array.empty[Double])), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), + LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) + points.foreach { p => + val bytes = py.serializeLabeledPoint(p) + val q = py.deserializeLabeledPoint(bytes) + assert(q.label === p.label) + assert(q.features.getClass === p.features.getClass) + assert(q.features === p.features) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cfe8a27fcb71e..7972ceea1fe8a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.linalg import org.scalatest.FunSuite +import org.apache.spark.SparkException + class VectorsSuite extends FunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) @@ -100,4 +102,27 @@ class VectorsSuite extends FunSuite { assert(vec2(6) === 4.0) assert(vec2(7) === 0.0) } + + test("parse vectors") { + val vectors = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(1.0), + Vectors.dense(1.0E6, 0.0, -2.0e-7), + Vectors.sparse(0, Array.empty[Int], Array.empty[Double]), + Vectors.sparse(1, Array(0), Array(1.0)), + Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0))) + vectors.foreach { v => + val v1 = Vectors.parse(v.toString) + assert(v.getClass === v1.getClass) + assert(v === v1) + } + + val malformatted = Seq("1", "[1,,]", "[1,2b]", "(1,[1,2])", "([1],[2.0,1.0])") + malformatted.foreach { s => + intercept[SparkException] { + Vectors.parse(s) + println(s"Didn't detect malformatted string $s.") + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala new file mode 100644 index 0000000000000..d9308aaba6ee1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors + +class LabeledPointSuite extends FunSuite { + + test("parse labeled points") { + val points = Seq( + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) + points.foreach { p => + assert(p === LabeledPointParser.parse(p.toString)) + } + } + + test("parse labeled points with v0.9 format") { + val point = LabeledPointParser.parse("1.0,1.0 0.0 -2.0") + assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 3d05fb68988c8..c14870fb969a8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -160,5 +160,33 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { } } -} + test("loadVectors") { + val vectors = sc.parallelize(Seq( + Vectors.dense(1.0, 2.0), + Vectors.sparse(2, Array(1), Array(-1.0)), + Vectors.dense(0.0, 1.0) + ), 2) + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "vectors") + val path = outputDir.toURI.toString + vectors.saveAsTextFile(path) + val loaded = loadVectors(sc, path) + assert(vectors.collect().toSet === loaded.collect().toSet) + Utils.deleteRecursively(tempDir) + } + test("loadLabeledPoints") { + val points = sc.parallelize(Seq( + LabeledPoint(1.0, Vectors.dense(1.0, 2.0)), + LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) + ), 2) + val tempDir = Files.createTempDir() + val outputDir = new File(tempDir, "points") + val path = outputDir.toURI.toString + points.saveAsTextFile(path) + val loaded = loadLabeledPoints(sc, path) + assert(points.collect().toSet === loaded.collect().toSet) + Utils.deleteRecursively(tempDir) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala new file mode 100644 index 0000000000000..f68fb95eac4e4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException + +class NumericParserSuite extends FunSuite { + + test("parser") { + val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" + val parsed = NumericParser.parse(s).asInstanceOf[Seq[_]] + assert(parsed(0).asInstanceOf[Seq[_]] === Seq(1.0, 2.0e3)) + assert(parsed(1).asInstanceOf[Double] === -4.0) + assert(parsed(2).asInstanceOf[Array[Double]] === Array(5.0e-6, 7.0e8)) + assert(parsed(3).asInstanceOf[Double] === 9.0) + + val malformatted = Seq("a", "[1,,]", "0.123.4", "1 2", "3+4") + malformatted.foreach { s => + intercept[SparkException] { + NumericParser.parse(s) + println(s"Didn't detect malformatted string $s.") + } + } + } +} diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 802a27a8da14d..a411a5d5914e0 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -22,6 +22,7 @@ from pyspark.mllib.linalg import SparseVector from pyspark.serializers import Serializer + """ Common utilities shared throughout MLlib, primarily for dealing with different data types. These include: @@ -147,7 +148,7 @@ def _serialize_sparse_vector(v): return ba -def _deserialize_double_vector(ba): +def _deserialize_double_vector(ba, offset=0): """Deserialize a double vector from a mutually understood format. 
>>> x = array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0]) @@ -160,43 +161,46 @@ def _deserialize_double_vector(ba): if type(ba) != bytearray: raise TypeError("_deserialize_double_vector called on a %s; " "wanted bytearray" % type(ba)) - if len(ba) < 5: + nb = len(ba) - offset + if nb < 5: raise TypeError("_deserialize_double_vector called on a %d-byte array, " - "which is too short" % len(ba)) - if ba[0] == DENSE_VECTOR_MAGIC: - return _deserialize_dense_vector(ba) - elif ba[0] == SPARSE_VECTOR_MAGIC: - return _deserialize_sparse_vector(ba) + "which is too short" % nb) + if ba[offset] == DENSE_VECTOR_MAGIC: + return _deserialize_dense_vector(ba, offset) + elif ba[offset] == SPARSE_VECTOR_MAGIC: + return _deserialize_sparse_vector(ba, offset) else: raise TypeError("_deserialize_double_vector called on bytearray " "with wrong magic") -def _deserialize_dense_vector(ba): +def _deserialize_dense_vector(ba, offset=0): """Deserialize a dense vector into a numpy array.""" - if len(ba) < 5: + nb = len(ba) - offset + if nb < 5: raise TypeError("_deserialize_dense_vector called on a %d-byte array, " - "which is too short" % len(ba)) - length = ndarray(shape=[1], buffer=ba, offset=1, dtype=int32)[0] - if len(ba) != 8 * length + 5: + "which is too short" % nb) + length = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=int32)[0] + if nb < 8 * length + 5: raise TypeError("_deserialize_dense_vector called on bytearray " "with wrong length") - return _deserialize_numpy_array([length], ba, 5) + return _deserialize_numpy_array([length], ba, offset + 5) -def _deserialize_sparse_vector(ba): +def _deserialize_sparse_vector(ba, offset=0): """Deserialize a sparse vector into a MLlib SparseVector object.""" - if len(ba) < 9: + nb = len(ba) - offset + if nb < 9: raise TypeError("_deserialize_sparse_vector called on a %d-byte array, " - "which is too short" % len(ba)) - header = ndarray(shape=[2], buffer=ba, offset=1, dtype=int32) + "which is too short" % nb) + header = ndarray(shape=[2], buffer=ba, offset=offset + 1, dtype=int32) size = header[0] nonzeros = header[1] - if len(ba) != 9 + 12 * nonzeros: + if nb < 9 + 12 * nonzeros: raise TypeError("_deserialize_sparse_vector called on bytearray " "with wrong length") - indices = _deserialize_numpy_array([nonzeros], ba, 9, dtype=int32) - values = _deserialize_numpy_array([nonzeros], ba, 9 + 4 * nonzeros, dtype=float64) + indices = _deserialize_numpy_array([nonzeros], ba, offset + 9, dtype=int32) + values = _deserialize_numpy_array([nonzeros], ba, offset + 9 + 4 * nonzeros, dtype=float64) return SparseVector(int(size), indices, values) @@ -243,7 +247,23 @@ def _deserialize_double_matrix(ba): def _serialize_labeled_point(p): - """Serialize a LabeledPoint with a features vector of any type.""" + """ + Serialize a LabeledPoint with a features vector of any type. 
+ + >>> from pyspark.mllib.regression import LabeledPoint + >>> dp0 = LabeledPoint(0.5, array([1.0, 2.0, 3.0, 4.0, -1.0, 0.0, -0.0])) + >>> dp1 = _deserialize_labeled_point(_serialize_labeled_point(dp0)) + >>> dp1.label == dp0.label + True + >>> array_equal(dp1.features, dp0.features) + True + >>> sp0 = LabeledPoint(0.0, SparseVector(4, [1, 3], [3.0, 5.5])) + >>> sp1 = _deserialize_labeled_point(_serialize_labeled_point(sp0)) + >>> sp1.label == sp1.label + True + >>> sp1.features == sp0.features + True + """ from pyspark.mllib.regression import LabeledPoint serialized_features = _serialize_double_vector(p.features) header = bytearray(9) @@ -252,6 +272,16 @@ def _serialize_labeled_point(p): header_float[0] = p.label return header + serialized_features +def _deserialize_labeled_point(ba, offset=0): + """Deserialize a LabeledPoint from a mutually understood format.""" + from pyspark.mllib.regression import LabeledPoint + if type(ba) != bytearray: + raise TypeError("Expecting a bytearray but got %s" % type(ba)) + if ba[offset] != LABELED_POINT_MAGIC: + raise TypeError("Expecting magic number %d but got %d" % (LABELED_POINT_MAGIC, ba[0])) + label = ndarray(shape=[1], buffer=ba, offset=offset + 1, dtype=float64)[0] + features = _deserialize_double_vector(ba, offset + 9) + return LabeledPoint(label, features) def _copyto(array, buffer, offset, shape, dtype): """ diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 276684272068b..db39ed0acdb66 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -43,11 +43,11 @@ def __init__(self, size, *args): or two sorted lists containing indices and values. >>> print SparseVector(4, {1: 1.0, 3: 5.5}) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) >>> print SparseVector(4, [(1, 1.0), (3, 5.5)]) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) >>> print SparseVector(4, [1, 3], [1.0, 5.5]) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) """ self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" @@ -160,10 +160,9 @@ def squared_distance(self, other): return result def __str__(self): - inds = self.indices - vals = self.values - entries = ", ".join(["{0}: {1}".format(inds[i], vals[i]) for i in xrange(len(inds))]) - return "[" + entries + "]" + inds = "[" + ",".join([str(i) for i in self.indices]) + "]" + vals = "[" + ",".join([str(v) for v in self.values]) + "]" + return "(" + ",".join((str(self.size), inds, vals)) + ")" def __repr__(self): inds = self.indices @@ -213,11 +212,11 @@ def sparse(size, *args): or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) >>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) >>> print Vectors.sparse(4, [1, 3], [1.0, 5.5]) - [1: 1.0, 3: 5.5] + (4,[1,3],[1.0,5.5]) """ return SparseVector(size, *args) @@ -232,6 +231,21 @@ def dense(elements): """ return array(elements, dtype=float64) + @staticmethod + def stringify(vector): + """ + Converts a vector into a string, which can be recognized by + Vectors.parse(). 
+ + >>> Vectors.stringify(Vectors.sparse(2, [1], [1.0])) + '(2,[1],[1.0])' + >>> Vectors.stringify(Vectors.dense([0.0, 1.0])) + '[0.0,1.0]' + """ + if type(vector) == SparseVector: + return str(vector) + else: + return "[" + ",".join([str(v) for v in vector]) + "]" def _test(): import doctest diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index bc7de6d2e8958..b84bc531dec8c 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -23,7 +23,7 @@ _serialize_double_vector, _deserialize_double_vector, \ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ _linear_predictor_typecheck, _have_scipy, _scipy_issparse -from pyspark.mllib.linalg import SparseVector +from pyspark.mllib.linalg import SparseVector, Vectors class LabeledPoint(object): @@ -44,6 +44,9 @@ def __init__(self, label, features): else: raise TypeError("Expected NumPy array, list, SparseVector, or scipy.sparse matrix") + def __str__(self): + return "(" + ",".join((str(self.label), Vectors.stringify(self.features))) + ")" + class LinearModel(object): """A linear model that has a vector of coefficients and an intercept.""" diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 0e5f4520b9402..e24c144f458bd 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -19,7 +19,10 @@ from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib._common import _convert_vector +from pyspark.mllib._common import _convert_vector, _deserialize_labeled_point +from pyspark.rdd import RDD +from pyspark.serializers import NoOpSerializer + class MLUtils: @@ -105,24 +108,18 @@ def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=Non >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name, True).collect() >>> tempFile.close() - >>> examples[0].label - 1.0 - >>> examples[0].features.size - 6 - >>> print examples[0].features - [0: 1.0, 2: 2.0, 4: 3.0] - >>> examples[1].label - 0.0 - >>> examples[1].features.size - 6 - >>> print examples[1].features - [] - >>> examples[2].label - 0.0 - >>> examples[2].features.size - 6 - >>> print examples[2].features - [1: 4.0, 3: 5.0, 5: 6.0] + >>> type(examples[0]) == LabeledPoint + True + >>> print examples[0] + (1.0,(6,[0,2,4],[1.0,2.0,3.0])) + >>> type(examples[1]) == LabeledPoint + True + >>> print examples[1] + (0.0,(6,[],[])) + >>> type(examples[2]) == LabeledPoint + True + >>> print examples[2] + (0.0,(6,[1,3,5],[4.0,5.0,6.0])) >>> multiclass_examples[1].label -1.0 """ @@ -158,6 +155,40 @@ def saveAsLibSVMFile(data, dir): lines.saveAsTextFile(dir) + @staticmethod + def loadLabeledPoints(sc, path, minPartitions=None): + """ + Load labeled points saved using RDD.saveAsTextFile. 
+ + @param sc: Spark context + @param path: file or directory path in any Hadoop-supported file + system URI + @param minPartitions: min number of partitions + @return: labeled data stored as an RDD of LabeledPoint + + >>> from tempfile import NamedTemporaryFile + >>> from pyspark.mllib.util import MLUtils + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ + LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> tempFile = NamedTemporaryFile(delete=True) + >>> tempFile.close() + >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) + >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect() + >>> type(loaded[0]) == LabeledPoint + True + >>> print examples[0] + (1.1,(3,[0,2],[-1.23,4.56e-07])) + >>> type(examples[1]) == LabeledPoint + True + >>> print examples[1] + (0.0,[1.01,2.02,3.03]) + """ + minPartitions = minPartitions or min(sc.defaultParallelism, 2) + jSerialized = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) + serialized = RDD(jSerialized, sc, NoOpSerializer()) + return serialized.map(lambda bytes: _deserialize_labeled_point(bytearray(bytes))) + + def _test(): import doctest from pyspark.context import SparkContext From 1765c8d0ddf6bb5bc3c21f994456eba04c581de4 Mon Sep 17 00:00:00 2001 From: Colin McCabe Date: Wed, 4 Jun 2014 15:56:29 -0700 Subject: [PATCH 03/18] SPARK-1518: FileLogger: Fix compile against Hadoop trunk In Hadoop trunk (currently Hadoop 3.0.0), the deprecated FSDataOutputStream#sync() method has been removed. Instead, we should call FSDataOutputStream#hflush, which does the same thing as the deprecated method used to do. Author: Colin McCabe Closes #898 from cmccabe/SPARK-1518 and squashes the following commits: 752b9d7 [Colin McCabe] FileLogger: Fix compile against Hadoop trunk --- .../scala/org/apache/spark/util/FileLogger.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 0e6d21b22023a..6a95dc06e155d 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -61,6 +61,14 @@ private[spark] class FileLogger( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None + // The Hadoop APIs have changed over time, so we use reflection to figure out + // the correct method to use to flush a hadoop data stream. See SPARK-1518 + // for details. + private val hadoopFlushMethod = { + val cls = classOf[FSDataOutputStream] + scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync")) + } + private var writer: Option[PrintWriter] = None /** @@ -149,13 +157,13 @@ private[spark] class FileLogger( /** * Flush the writer to disk manually. * - * If the Hadoop FileSystem is used, the underlying FSDataOutputStream (r1.0.4) must be - * sync()'ed manually as it does not support flush(), which is invoked by when higher - * level streams are flushed. + * When using a Hadoop filesystem, we need to invoke the hflush or sync + * method. In HDFS, hflush guarantees that the data gets to all the + * DataNodes. 
*/ def flush() { writer.foreach(_.flush()) - hadoopDataStream.foreach(_.sync()) + hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) } /** From 11ded3f66f178e4d8d2b23491dd5e0ea23bcf719 Mon Sep 17 00:00:00 2001 From: Varakhedi Sujeet Date: Wed, 4 Jun 2014 16:01:56 -0700 Subject: [PATCH 04/18] SPARK-1790: Update EC2 scripts to support r3 instance types Author: Varakhedi Sujeet Closes #960 from sujeetv/ec2-r3 and squashes the following commits: 3cb9fd5 [Varakhedi Sujeet] SPARK-1790: Update EC2 scripts to support r3 instance --- ec2/spark_ec2.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 3af9f66e17dc2..9d5748ba4bc23 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -230,7 +230,12 @@ def get_spark_ami(opts): "c3.xlarge": "pvm", "c3.2xlarge": "pvm", "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm" + "c3.8xlarge": "pvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "r3.2xlarge": "hvm", + "r3.4xlarge": "hvm", + "r3.8xlarge": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -538,7 +543,12 @@ def get_num_disks(instance_type): "c3.xlarge": 2, "c3.2xlarge": 2, "c3.4xlarge": 2, - "c3.8xlarge": 2 + "c3.8xlarge": 2, + "r3.large": 1, + "r3.xlarge": 1, + "r3.2xlarge": 1, + "r3.4xlarge": 1, + "r3.8xlarge": 2 } if instance_type in disks_by_instance: return disks_by_instance[instance_type] From abea2d4ff099036c67fc73136d0e61d0d0e22123 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 16:45:53 -0700 Subject: [PATCH 05/18] Minor: Fix documentation error from apache/spark#946 Author: Ankur Dave Closes #970 from ankurdave/SPARK-1991_docfix and squashes the following commits: 6d07343 [Ankur Dave] Minor: Fix documentation error from apache/spark#946 --- .../src/main/scala/org/apache/spark/graphx/GraphLoader.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 2e814e34f9ad8..f4c79365b16da 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -49,8 +49,8 @@ object GraphLoader extends Logging { * @param canonicalOrientation whether to orient edges in the positive * direction * @param minEdgePartitions the number of partitions for the edge RDD - * @param edgeStorageLevel the desired storage level for the edge partitions. To set the vertex - * storage level, call [[org.apache.spark.graphx.Graph#persistVertices]]. + * @param edgeStorageLevel the desired storage level for the edge partitions + * @param vertexStorageLevel the desired storage level for the vertex partitions */ def edgeListFile( sc: SparkContext, From b77c19be053125fde99b098ec1e1162f25b5433c Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 4 Jun 2014 22:56:49 -0700 Subject: [PATCH 06/18] Fix issue in ReplSuite with hadoop-provided profile. When building the assembly with the maven "hadoop-provided" profile, the executors were failing to come up because Hadoop classes were not found in the classpath anymore; so add them explicitly to the classpath using spark.executor.extraClassPath. This is only needed for the local-cluster mode, but doesn't affect other tests, so it's added for all of them to keep the code simpler. 
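In outline, the fix saves, overrides, and then restores the executor classpath property around each interpreter run. A simplified sketch of that pattern (not the exact test code, which appears in the diff below and restores the value inline rather than in a finally block):

    val key = "spark.executor.extraClassPath"
    val classpath = System.getProperty("java.class.path")  // stand-in for the computed test classpath
    val previous = System.getProperty(key)                 // may be null if it was never set
    System.setProperty(key, classpath)
    try {
      // ... run the REPL session against the given master ...
    } finally {
      if (previous != null) System.setProperty(key, previous) else System.clearProperty(key)
    }
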
Author: Marcelo Vanzin Closes #781 from vanzin/repl-test-fix and squashes the following commits: 4f0a3b0 [Marcelo Vanzin] Fix issue in ReplSuite with hadoop-provided profile. --- .../scala/org/apache/spark/repl/ReplSuite.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 98cdfd0054713..7c765edd55027 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -32,6 +32,8 @@ import org.apache.spark.util.Utils class ReplSuite extends FunSuite { def runInterpreter(master: String, input: String): String = { + val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" + val in = new BufferedReader(new StringReader(input + "\n")) val out = new StringWriter() val cl = getClass.getClassLoader @@ -44,13 +46,23 @@ class ReplSuite extends FunSuite { } } } + val classpath = paths.mkString(File.pathSeparator) + + val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH) + System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) + val interp = new SparkILoop(in, new PrintWriter(out), master) org.apache.spark.repl.Main.interp = interp - interp.process(Array("-classpath", paths.mkString(File.pathSeparator))) + interp.process(Array("-classpath", classpath)) org.apache.spark.repl.Main.interp = null if (interp.sparkContext != null) { interp.sparkContext.stop() } + if (oldExecutorClasspath != null) { + System.setProperty(CONF_EXECUTOR_CLASSPATH, oldExecutorClasspath) + } else { + System.clearProperty(CONF_EXECUTOR_CLASSPATH) + } return out.toString } From 7c160293d6d708718d566e700cfb407a31280b89 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 5 Jun 2014 11:27:33 -0700 Subject: [PATCH 07/18] [SPARK-2029] Bump pom.xml version number of master branch to 1.1.0-SNAPSHOT. Author: Takuya UESHIN Closes #974 from ueshin/issues/SPARK-2029 and squashes the following commits: e19e8f4 [Takuya UESHIN] Bump version number to 1.1.0-SNAPSHOT. 
--- assembly/pom.xml | 2 +- bagel/pom.xml | 2 +- core/pom.xml | 2 +- examples/pom.xml | 2 +- external/flume/pom.xml | 2 +- external/kafka/pom.xml | 2 +- external/mqtt/pom.xml | 2 +- external/twitter/pom.xml | 2 +- external/zeromq/pom.xml | 2 +- extras/java8-tests/pom.xml | 2 +- extras/spark-ganglia-lgpl/pom.xml | 2 +- graphx/pom.xml | 2 +- mllib/pom.xml | 2 +- pom.xml | 2 +- repl/pom.xml | 2 +- sql/catalyst/pom.xml | 2 +- sql/core/pom.xml | 2 +- sql/hive/pom.xml | 2 +- streaming/pom.xml | 2 +- tools/pom.xml | 2 +- yarn/alpha/pom.xml | 2 +- yarn/pom.xml | 2 +- yarn/stable/pom.xml | 2 +- 23 files changed, 23 insertions(+), 23 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 963357b9ab167..0c60b66c3daca 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index 355f437c5b16a..c8e39a415af28 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 0777c5b1f03d4..0c746175afa73 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/examples/pom.xml b/examples/pom.xml index 874bcd7916f35..4f6d7fdb87d47 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 6aec215687fe0..c1f581967777b 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 979eb0ca624bd..d014a7aad0fca 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 7b2dc5ba1d7f9..4980208cba3b0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 5766d3a0d44ec..7073bd4404d9c 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 4ed4196bd8662..cf306e0dca8bd 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 602f66f9c5cf1..955ec1a8c3033 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 11ac827ed54a0..22ea330b4374d 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index dc108d2fe7fbd..7d5d83e7f3bb9 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ 
org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index cdd33dbb7970d..4aae2026dcaf2 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/pom.xml b/pom.xml index fcd6f66b4414a..87c8e29ad1069 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ diff --git a/repl/pom.xml b/repl/pom.xml index bcdb24b040cc8..4a66408ef3d2d 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 8d2e4baf69e30..6c78c34486010 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index fb3b190b4ec5a..e65ca6be485e3 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 9254b70e64a08..5ede76e5c3904 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 6435224a14674..f506d6ce34a6f 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 1875c497bc61c..79cd8551d0722 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index e076ca1d44b97..b8a631dd0bb3b 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 2811ffffbdfa2..ef7066ef1fdfc 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 0780f251b595c..0931beb505508 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.0.0-SNAPSHOT + 1.1.0-SNAPSHOT ../pom.xml From 89cdbb087cb2f0d03be2dd77440300c6bd61c792 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 5 Jun 2014 11:39:35 -0700 Subject: [PATCH 08/18] SPARK-1677: allow user to disable output dir existence checking https://issues.apache.org/jira/browse/SPARK-1677 For compatibility with older versions of Spark it would be nice to have an option `spark.hadoop.validateOutputSpecs` (default true) for the user to disable the output directory existence checking Author: CodingCat Closes #947 from CodingCat/SPARK-1677 and squashes the following commits: 7930f83 [CodingCat] miao c0c0e03 [CodingCat] bug fix and doc update 5318562 [CodingCat] bug fix 13219b5 [CodingCat] allow user to disable output dir existence checking --- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +++-- .../scala/org/apache/spark/FileSuite.scala | 22 +++++++++++++++++++ docs/configuration.md | 8 +++++++ 3 files changed, 34 insertions(+), 2 deletions(-) diff --git 
a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index f2ce3cbd47f93..8909980957058 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -737,7 +737,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance - if (jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && + jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) { // FileOutputFormat ignores the filesystem parameter jobFormat.checkOutputSpecs(job) } @@ -803,7 +804,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) { + if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) && + outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(conf) conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 1f2206b1f0379..070e974657860 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -230,6 +230,17 @@ class FileSuite extends FunSuite with LocalSparkContext { } } + test ("allow user to disable the output directory existence checking (old Hadoop API") { + val sf = new SparkConf() + sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(sf) + val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) + randomRDD.saveAsTextFile(tempDir.getPath + "/output") + assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) + randomRDD.saveAsTextFile(tempDir.getPath + "/output") + assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) + } + test ("prevent user from overwriting the empty directory (new Hadoop API)") { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) @@ -248,6 +259,17 @@ class FileSuite extends FunSuite with LocalSparkContext { } } + test ("allow user to disable the output directory existence checking (new Hadoop API") { + val sf = new SparkConf() + sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(sf) + val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempDir.getPath + "/output") + assert(new File(tempDir.getPath + "/output/part-r-00000").exists() === true) + } + test ("save Hadoop Dataset through old Hadoop API") { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) diff --git a/docs/configuration.md b/docs/configuration.md index 0697f7fc2fd91..71fafa573467f 100644 --- a/docs/configuration.md +++ 
b/docs/configuration.md @@ -487,6 +487,14 @@ Apart from these, the following properties are also available, and may be useful this duration will be cleared as well. + + spark.hadoop.validateOutputSpecs + true + If set to true, validates the output specification (e.g. checking if the output directory already exists) + used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing + output directories. We recommend that users do not disable this except if trying to achieve compatibility with + previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + #### Networking From e4c11eef2f64df0b6a432f40b669486d91ca6352 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 5 Jun 2014 12:00:31 -0700 Subject: [PATCH 09/18] [SPARK-2036] [SQL] CaseConversionExpression should check if the evaluated value is null. `CaseConversionExpression` should check if the evaluated value is `null`. Author: Takuya UESHIN Closes #982 from ueshin/issues/SPARK-2036 and squashes the following commits: 61e1c54 [Takuya UESHIN] Add check if the evaluated value is null. --- .../catalyst/expressions/stringOperations.scala | 8 ++++++-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 14 ++++++++++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 8 ++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index dcded0774180e..420303408451f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -81,8 +81,12 @@ trait CaseConversionExpression { def dataType: DataType = StringType override def eval(input: Row): Any = { - val converted = child.eval(input) - convert(converted.toString) + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.toString) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 95860e6683f67..e2ad3915d3134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -322,6 +322,13 @@ class SQLQuerySuite extends QueryTest { (2, "B"), (3, "C"), (4, "D"))) + + checkAnswer( + sql("SELECT n, UPPER(s) FROM nullStrings"), + Seq( + (1, "ABC"), + (2, "ABC"), + (3, null))) } test("system function lower()") { @@ -334,6 +341,13 @@ class SQLQuerySuite extends QueryTest { (4, "d"), (5, "e"), (6, "f"))) + + checkAnswer( + sql("SELECT n, LOWER(s) FROM nullStrings"), + Seq( + (1, "abc"), + (2, "abc"), + (3, null))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 944f520e43515..876bd1636aab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -106,4 +106,12 @@ object TestData { NullInts(null) :: Nil ) nullInts.registerAsTable("nullInts") + + case class NullStrings(n: Int, s: String) + val nullStrings = + TestSQLContext.sparkContext.parallelize( + NullStrings(1, "abc") :: + NullStrings(2, "ABC") :: + NullStrings(3, null) :: Nil) + 
nullStrings.registerAsTable("nullStrings") } From f6143f127db59e7f5a00fd70605f85248869347d Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 5 Jun 2014 13:06:46 -0700 Subject: [PATCH 10/18] HOTFIX: Remove generated-mima-excludes file after runing MIMA. This has been causing some false failures on PR's that don't merge correctly. Author: Patrick Wendell Closes #971 from pwendell/mima and squashes the following commits: 1dc80aa [Patrick Wendell] HOTFIX: Remove generated-mima-excludes file after runing MIMA. --- dev/mima | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/mima b/dev/mima index d4099990254cc..ab6bd4469b0e8 100755 --- a/dev/mima +++ b/dev/mima @@ -31,4 +31,5 @@ if [ $ret_val != 0 ]; then echo "NOTE: Exceptions to binary compatibility can be added in project/MimaExcludes.scala" fi +rm -f .generated-mima-excludes exit $ret_val From 5473aa7c02916022430493637b1492554b48c30b Mon Sep 17 00:00:00 2001 From: Kalpit Shah Date: Thu, 5 Jun 2014 13:07:26 -0700 Subject: [PATCH 11/18] sbt 0.13.X should be using sbt-assembly 0.11.X https://github.com/sbt/sbt-assembly/blob/master/README.md Author: Kalpit Shah Closes #555 from kalpit/upgrade/sbtassembly and squashes the following commits: 1fa7324 [Kalpit Shah] sbt 0.13.X should be using sbt-assembly 0.11.X --- project/plugins.sbt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 0cd16fd5bedd4..472819b9fb8ba 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,7 +4,7 @@ resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline. resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.10.2") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") From 668cb1defe735add91f4a5b7b8ebe7cfd5640b25 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 5 Jun 2014 13:13:33 -0700 Subject: [PATCH 12/18] Remove compile-scoped junit dependency. This avoids having junit classes showing up in the assembly jar. I verified that only test classes in the jtransforms package use junit. Author: Marcelo Vanzin Closes #794 from vanzin/junit-dep-exclusion and squashes the following commits: 274e1c2 [Marcelo Vanzin] Remove junit from assembly in sbt build also. ad950be [Marcelo Vanzin] Remove compile-scoped junit dependency. 
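On the sbt side the change boils down to attaching an exclusion rule to the breeze dependency; the Maven pom gets an equivalent exclusions block. A sketch of the relevant build-definition lines:

    // Drop breeze's transitive junit dependency so junit classes stay out of the assembly jar.
    val excludeJUnit = ExclusionRule(organization = "junit")
    libraryDependencies += "org.scalanlp" %% "breeze" % "0.7" excludeAll(excludeJUnit)
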
--- mllib/pom.xml | 8 ++++++++ project/SparkBuild.scala | 3 ++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mllib/pom.xml b/mllib/pom.xml index 4aae2026dcaf2..878cb83dbf783 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -50,6 +50,14 @@ org.scalanlp breeze_${scala.binary.version} 0.7 + + + + junit + junit + + org.scalatest diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index efb0b9319be13..d0049a8ac43aa 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -322,6 +322,7 @@ object SparkBuild extends Build { val excludeJruby = ExclusionRule(organization = "org.jruby") val excludeThrift = ExclusionRule(organization = "org.apache.thrift") val excludeServletApi = ExclusionRule(organization = "javax.servlet", artifact = "servlet-api") + val excludeJUnit = ExclusionRule(organization = "junit") def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark", version: String = "1.0.0", crossVersion: String = "2.10"): Option[sbt.ModuleID] = { @@ -466,7 +467,7 @@ object SparkBuild extends Build { previousArtifact := sparkPreviousArtifact("spark-mllib"), libraryDependencies ++= Seq( "org.jblas" % "jblas" % jblasVersion, - "org.scalanlp" %% "breeze" % "0.7" + "org.scalanlp" %% "breeze" % "0.7" excludeAll(excludeJUnit) ) ) From c7a183b2c2bca13565496495b4ae3a3a9f63f9ab Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 5 Jun 2014 17:42:08 -0700 Subject: [PATCH 13/18] [SPARK-2041][SQL] Correctly analyze queries where columnName == tableName. Author: Michael Armbrust Closes #985 from marmbrus/tableName and squashes the following commits: 3caaa27 [Michael Armbrust] Correctly analyze queries where columnName == tableName. --- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 ++++++ sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 3 +++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 5eb52d5350f55..2b8fbdcde9d37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -64,7 +64,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { // struct fields. val options = children.flatMap(_.output).flatMap { option => // If the first part of the desired name matches a qualifier for this possible match, drop it. - val remainingParts = if (option.qualifiers contains parts.head) parts.drop(1) else parts + val remainingParts = + if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index e2ad3915d3134..aa0c426f6fcb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,6 +28,12 @@ class SQLQuerySuite extends QueryTest { // Make sure the tables are loaded. 
TestData + test("SPARK-2041 column name equals tablename") { + checkAnswer( + sql("SELECT tableName FROM tableName"), + "test") + } + test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 876bd1636aab3..05de736bbce1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -114,4 +114,7 @@ object TestData { NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil) nullStrings.registerAsTable("nullStrings") + + case class TableName(tableName: String) + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") } From 3d3f8c8004da110ca97973119e9d9f04f878ee81 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Thu, 5 Jun 2014 17:44:46 -0700 Subject: [PATCH 14/18] Use pluggable clock in DAGSheduler #SPARK-2031 DAGScheduler supports pluggable clock like what TaskSetManager does. Author: CrazyJvm Closes #976 from CrazyJvm/clock and squashes the following commits: 6779a4c [CrazyJvm] Use pluggable clock in DAGSheduler --- .../org/apache/spark/scheduler/DAGScheduler.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ccff6a3d1aebc..e09a4221e8315 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SystemClock, Clock, Utils} /** * The high-level scheduling layer that implements stage-oriented scheduling. 
It computes a DAG of @@ -61,7 +61,8 @@ class DAGScheduler( listenerBus: LiveListenerBus, mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, - env: SparkEnv) + env: SparkEnv, + clock: Clock = SystemClock) extends Logging { import DAGScheduler._ @@ -781,7 +782,7 @@ class DAGScheduler( logDebug("New pending tasks: " + myPending) taskScheduler.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - stageToInfos(stage).submissionTime = Some(System.currentTimeMillis()) + stageToInfos(stage).submissionTime = Some(clock.getTime()) } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -807,11 +808,11 @@ class DAGScheduler( def markStageAsFinished(stage: Stage) = { val serviceTime = stageToInfos(stage).submissionTime match { - case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) + case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0) case _ => "Unknown" } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) + stageToInfos(stage).completionTime = Some(clock.getTime()) listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) runningStages -= stage } @@ -1015,7 +1016,7 @@ class DAGScheduler( return } val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq - stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) + stageToInfos(failedStage).completionTime = Some(clock.getTime()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", From 9bad0b73722fb359f14db864e69aa7efde3588c5 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Thu, 5 Jun 2014 17:45:38 -0700 Subject: [PATCH 15/18] [SPARK-2025] Unpersist edges of previous graph in Pregel Due to a bug introduced by apache/spark#497, Pregel does not unpersist replicated vertices from previous iterations. As a result, they stay cached until memory is full, wasting GC time. This PR corrects the problem by unpersisting both the edges and the replicated vertices of previous iterations. This is safe because the edges and replicated vertices of the current iteration are cached by the call to `g.cache()` and then materialized by the call to `messages.count()`. Therefore no unmaterialized RDDs depend on `prevG.edges`. I verified that no recomputation occurs by running PageRank with a custom patch to Spark that warns when a partition is recomputed. Thanks to Tim Weninger for reporting this bug. 
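The safety argument above is an instance of a general Spark pattern: cache and materialize the new iteration's data before unpersisting the previous iteration's, so that no unmaterialized RDD still depends on what is being dropped. A minimal sketch of that pattern, purely for illustration (the `step` and `update` names are made up and not part of this patch):

  import org.apache.spark.rdd.RDD

  // Cache and force the new iteration before dropping the old one, so that no
  // pending computation still depends on `prev` when it is unpersisted.
  def step[T](prev: RDD[T], update: RDD[T] => RDD[T]): RDD[T] = {
    val next = update(prev).cache()   // cache the new iteration's data
    next.count()                      // materialize it
    prev.unpersist(blocking = false)  // now safe: nothing still depends on prev
    next
  }

The line added in the diff below drops the previous graph's edges only after the current graph has been cached and materialized, exactly in this order.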
Author: Ankur Dave Closes #972 from ankurdave/SPARK-2025 and squashes the following commits: 13d5b07 [Ankur Dave] Unpersist edges of previous graph in Pregel --- graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 4572eab2875bb..5e55620147df8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -150,6 +150,7 @@ object Pregel extends Logging { oldMessages.unpersist(blocking=false) newVerts.unpersist(blocking=false) prevG.unpersistVertices(blocking=false) + prevG.edges.unpersist(blocking=false) // count the iteration i += 1 } From b45c13e7d798f97b92f1a6329528191b8d779c4f Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Thu, 5 Jun 2014 23:01:48 -0700 Subject: [PATCH 16/18] SPARK-2043: ExternalAppendOnlyMap doesn't always find matching keys The current implementation reads one key with the next hash code as it finishes reading the keys with the current hash code, which may cause it to miss some matches of the next key. This can cause operations like join to give the wrong result when reduce tasks spill to disk and there are hash collisions, as values won't be matched together. This PR fixes it by not reading in that next key, using a peeking iterator instead. Author: Matei Zaharia Closes #986 from mateiz/spark-2043 and squashes the following commits: 0959514 [Matei Zaharia] Added unit test for having many hash collisions 892debb [Matei Zaharia] SPARK-2043: don't read a key with the next hash code in ExternalAppendOnlyMap, instead use a buffered iterator to only read values with the current hash code. --- .../collection/ExternalAppendOnlyMap.scala | 10 +++-- .../ExternalAppendOnlyMapSuite.scala | 39 ++++++++++++++++++- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 170f09be21534..288badd3160f8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -20,6 +20,7 @@ package org.apache.spark.util.collection import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException} import java.util.Comparator +import scala.collection.BufferedIterator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -231,7 +232,7 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order private val sortedMap = currentMap.destructiveSortedIterator(comparator) - private val inputStreams = Seq(sortedMap) ++ spilledMaps + private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => val kcPairs = getMorePairs(it) @@ -246,13 +247,13 @@ class ExternalAppendOnlyMap[K, V, C]( * In the event of key hash collisions, this ensures no pairs are hidden from being merged. * Assume the given iterator is in sorted order. 
*/ - private def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = { + private def getMorePairs(it: BufferedIterator[(K, C)]): ArrayBuffer[(K, C)] = { val kcPairs = new ArrayBuffer[(K, C)] if (it.hasNext) { var kc = it.next() kcPairs += kc val minHash = kc._1.hashCode() - while (it.hasNext && kc._1.hashCode() == minHash) { + while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() kcPairs += kc } @@ -325,7 +326,8 @@ class ExternalAppendOnlyMap[K, V, C]( * * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. */ - private case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)]) + private class StreamBuffer( + val iterator: BufferedIterator[(K, C)], val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { def isEmpty = pairs.length == 0 diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index cdebefb67510c..deb780953579d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -277,6 +277,11 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { ("pomatoes", "eructation") // 568647356 ) + collisionPairs.foreach { case (w1, w2) => + // String.hashCode is documented to use a specific algorithm, but check just in case + assert(w1.hashCode === w2.hashCode) + } + (1 to 100000).map(_.toString).foreach { i => map.insert(i, i) } collisionPairs.foreach { case (w1, w2) => map.insert(w1, w2) @@ -296,7 +301,32 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { assert(kv._2.equals(expectedValue)) count += 1 } - assert(count == 100000 + collisionPairs.size * 2) + assert(count === 100000 + collisionPairs.size * 2) + } + + test("spilling with many hash collisions") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.0001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + + // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes + // problems if the map fails to group together the objects with the same code (SPARK-2043). + for (i <- 1 to 10) { + for (j <- 1 to 10000) { + map.insert(FixedHashObject(j, j % 2), 1) + } + } + + val it = map.iterator + var count = 0 + while (it.hasNext) { + val kv = it.next() + assert(kv._2 === 10) + count += 1 + } + assert(count === 10000) } test("spilling with hash collisions using the Int.MaxValue key") { @@ -317,3 +347,10 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { } } } + +/** + * A dummy class that always returns the same hash code, to easily test hash collisions + */ +case class FixedHashObject(val v: Int, val h: Int) extends Serializable { + override def hashCode(): Int = h +} From 41db44c428a10f4453462d002d226798bb8fbdda Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 5 Jun 2014 23:20:59 -0700 Subject: [PATCH 17/18] [SPARK-2050][SQL] LIKE, RLIKE and IN in HQL should not be case sensitive. Author: Michael Armbrust Closes #989 from marmbrus/caseSensitiveFuncitons and squashes the following commits: 681de54 [Michael Armbrust] LIKE, RLIKE and IN in HQL should not be case sensitive. 
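The terse summary above is implemented (see the diff that follows) with case-insensitive regular expressions used as extractors in the HiveQl pattern match. A minimal sketch of that mechanism, purely for illustration (`describe` is a made-up name, not part of this patch):

  // "(?i)" makes the pattern case-insensitive; a Regex used as an extractor
  // must match the entire string, so "LIKEWISE" would not match.
  val LIKE = "(?i)LIKE".r

  def describe(token: String): String = token match {
    case LIKE() => "a LIKE token"            // matches "LIKE", "like", "Like", ...
    case other  => "something else: " + other
  }

  // describe("like")     == "a LIKE token"
  // describe("LIKEWISE") == "something else: LIKEWISE"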
--- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e8a3ee5535b6e..c133bf2423190 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -791,6 +791,10 @@ private[hive] object HiveQl { val NOT = "(?i)NOT".r val TRUE = "(?i)TRUE".r val FALSE = "(?i)FALSE".r + val LIKE = "(?i)LIKE".r + val RLIKE = "(?i)RLIKE".r + val REGEXP = "(?i)REGEXP".r + val IN = "(?i)IN".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -871,14 +875,14 @@ private[hive] object HiveQl { case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("LIKE", left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token("RLIKE", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("REGEXP", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) + case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => IsNotNull(nodeToExpr(child)) case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("IN", Nil) :: value :: list) => + case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => In(nodeToExpr(value), list.map(nodeToExpr)) case Token("TOK_FUNCTION", Token("between", Nil) :: From 8d85359f84cc67996b4bcf1670a8a98ab4f914a2 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Thu, 5 Jun 2014 23:33:12 -0700 Subject: [PATCH 18/18] [SPARK-1552] Fix type comparison bug in {map,outerJoin}Vertices In GraphImpl, mapVertices and outerJoinVertices use a more efficient implementation when the map function conserves vertex attribute types. This is implemented by comparing the ClassTags of the old and new vertex attribute types. However, ClassTags store erased types, so the comparison will return a false positive for types with different type parameters, such as Option[Int] and Option[Double]. This PR resolves the problem by requesting that the compiler generate evidence of equality between the old and new vertex attribute types, and providing a default value for the evidence parameter if the two types are not equal. The methods can then check the value of the evidence parameter to see whether the types are equal. It also adds a test called "mapVertices changing type with same erased type" that failed before the PR and succeeds now. Callers of mapVertices and outerJoinVertices can no longer use a wildcard for a graph's VD type. To avoid "Error occurred in an application involving default arguments," they must bind VD to a type parameter, as this PR does for ShortestPaths and LabelPropagation. 
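To make the erasure problem concrete, the sketch below contrasts the old ClassTag comparison with the type-equality evidence this patch introduces; it is purely illustrative (`sameErasedType` and `sameType` are made-up names, not part of the patch):

  import scala.reflect.{classTag, ClassTag}

  // ClassTags carry only the erased runtime class, so Option[Int] and Option[Double]
  // both report Option and compare equal -- the false positive described above.
  def sameErasedType[A: ClassTag, B: ClassTag]: Boolean = classTag[A] == classTag[B]

  // Requesting =:= evidence with a null default lets the compiler decide: the implicit
  // is found only when A and B are the same type, type parameters included.
  def sameType[A, B](implicit eq: A =:= B = null): Boolean = eq != null

  // sameErasedType[Option[Int], Option[Double]]  -> true  (both erase to Option)
  // sameType[Option[Int], Option[Double]]        -> false
  // sameType[Option[Int], Option[Int]]           -> true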
Author: Ankur Dave Closes #967 from ankurdave/SPARK-1552 and squashes the following commits: 68a4fff [Ankur Dave] Undo conserve naming 7388705 [Ankur Dave] Remove unnecessary ClassTag for VD parameters a704e5f [Ankur Dave] Use type equality constraint with default argument 29a5ab7 [Ankur Dave] Add failing test f458c83 [Ankur Dave] Revert "[SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices" 16d6af8 [Ankur Dave] [SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices --- .../scala/org/apache/spark/graphx/Graph.scala | 5 ++-- .../apache/spark/graphx/impl/GraphImpl.scala | 14 ++++++++--- .../spark/graphx/lib/LabelPropagation.scala | 2 +- .../spark/graphx/lib/ShortestPaths.scala | 2 +- .../org/apache/spark/graphx/GraphSuite.scala | 25 +++++++++++++++++++ 5 files changed, 40 insertions(+), 8 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 14ae50e6657fd..4db45c9af8fae 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -138,7 +138,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} * */ - def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] /** * Transforms each edge attribute in the graph using the map function. The map function is not @@ -348,7 +349,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} */ def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (mapFunc: (VertexId, VD, Option[U]) => VD2) + (mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null) : Graph[VD2, ED] /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 15ea05cbe281d..ccdaa82eb9162 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -104,8 +104,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse()) } - override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + override def mapVertices[VD2: ClassTag] + (f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // The map preserves type, so we can use incremental replication val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() @@ -232,8 +235,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) - (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + (updateF: (VertexId, VD, Option[U]) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // updateF preserves type, so we can use incremental replication val newVerts = 
vertices.leftJoin(other)(updateF).cache() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 776bfb8dd6bfa..82e9e06515179 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -41,7 +41,7 @@ object LabelPropagation { * * @return a graph with vertex attributes containing the label of community affiliation */ - def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(e: EdgeTriplet[VertexId, ED]) = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index bba070f256d80..590f0474957dd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -49,7 +49,7 @@ object ShortestPaths { * @return a graph where each vertex attribute is a map containing the shortest-path distance to * each reachable landmark vertex. */ - def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { val spGraph = graph.mapVertices { (vid, attr) => if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap() } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index abc25d0671133..6506bac73d71c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("mapVertices changing type with same erased type") { + withSpark { sc => + val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])]( + (1L, Some(1)), + (2L, Some(2)), + (3L, Some(3)) + )) + val edges = sc.parallelize(Array( + Edge(1L, 2L, 0), + Edge(2L, 3L, 0), + Edge(3L, 1L, 0) + )) + val graph0 = Graph(vertices, edges) + // Trigger initial vertex replication + graph0.triplets.foreach(x => {}) + // Change type of replicated vertices, but preserve erased type + val graph1 = graph0.mapVertices { + case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + } + // Access replicated vertices, exposing the erased type + val graph2 = graph1.mapTriplets(t => t.srcAttr.get) + assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + } + } + test("mapEdges") { withSpark { sc => val n = 3