Commit 148126f
Added return type to public functions
Plus some cosmetic changes.
sagacifyTestUser committed Mar 23, 2015
1 parent 12dddb4 commit 148126f
Showing 3 changed files with 33 additions and 35 deletions.
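The idiom behind the change, sketched on a hypothetical pair of classes rather than the Spark code itself: public parameters get explicit type annotations, and fluent setters are declared to return this.type, so chained calls keep the concrete subclass type instead of widening to the base class.

// Hypothetical sketch of the return-type idiom (not code from this commit):
// explicit types on public members, and this.type on setters so they chain.
class BaseConfig {
  private var minLength: Int = 1

  // Returning this.type means a subclass instance gets its own type back,
  // so subclass-specific setters can still be chained after this call.
  def setMinLength(value: Int): this.type = {
    minLength = value
    this
  }

  def getMinLength: Int = minLength
}

class TokenConfig extends BaseConfig {
  private var pattern: String = "\\s+"

  def setPattern(value: String): this.type = {
    pattern = value
    this
  }

  def getPattern: String = pattern
}

// Chaining compiles because setMinLength on a TokenConfig returns TokenConfig,
// not BaseConfig: new TokenConfig().setMinLength(3).setPattern("\\p{L}+")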
13 changes: 7 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -56,10 +56,10 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param for minimum token length, default is one to avoid returning empty strings
* @group param
*/
-val minTokenLength = new IntParam(this, "minLength", "minimum token length", Some(1))
+val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))

/** @group setParam */
-def setMinTokenLength(value: Int) = set(minTokenLength, value)
+def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)

/** @group getParam */
def getMinTokenLength: Int = get(minTokenLength)
@@ -68,10 +68,11 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex as splitting on gaps (true) or matching tokens (false)
* @group param
*/
-val gaps = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens", Some(false))
+val gaps: BooleanParam = new BooleanParam(this, "gaps",
+"Set regex to match gaps or tokens", Some(false))

/** @group setParam */
-def setGaps(value: Boolean) = set(gaps, value)
+def setGaps(value: Boolean): this.type = set(gaps, value)

/** @group getParam */
def getGaps: Boolean = get(gaps)
@@ -80,11 +81,11 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
* param sets regex pattern used by tokenizer
* @group param
*/
-val pattern = new Param(this, "pattern",
+val pattern: Param[scala.util.matching.Regex] = new Param(this, "pattern",
"regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+".r))

/** @group setParam */
-def setPattern(value: String) = set(pattern, value.r)
+def setPattern(value: String): this.type = set(pattern, value.r)

/** @group getParam */
def getPattern: String = get(pattern).toString
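For context, the setters annotated above are meant to be chained; a minimal usage sketch, mirroring the configuration exercised in the updated TokenizerSuite further down (column names and parameter values are taken from those tests, not from any fixed contract):

import org.apache.spark.ml.feature.RegexTokenizer

// Sketch only: the chain works because each setter returns this.type.
val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")      // column holding the raw text
  .setOutputCol("tokens")      // column that will receive the token sequence
  .setPattern("\\s")           // stored as a Regex via value.r (see setPattern above)
  .setGaps(true)               // treat the pattern as gaps between tokens
  .setMinTokenLength(0)        // keep empty strings produced by the split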
@@ -58,7 +58,7 @@ public void RegexTokenizer() {

List<String> t = Arrays.asList(
"{\"rawText\": \"Test of tok.\", \"wantedTokens\": [\"Test\", \"of\", \"tok.\"]}",
-"{\"rawText\": \"Te,st. punct\", \"wantedTokens\": [\"Te,st.\",\"\",\"punct\"]}");
+"{\"rawText\": \"Te,st. punct\", \"wantedTokens\": [\"Te,st.\", \"\", \"punct\"]}");

JavaRDD<String> myRdd = jsc.parallelize(t);
DataFrame dataset = jsql.jsonRDD(myRdd);
@@ -23,7 +23,8 @@ import org.apache.spark.SparkException
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

-case class TextData(rawText : String, wantedTokens: Seq[String])
+case class TextData(rawText: String, wantedTokens: Seq[String])
+
class TokenizerSuite extends FunSuite with MLlibTestSparkContext {

@transient var sqlContext: SQLContext = _
@@ -33,70 +34,66 @@ class TokenizerSuite extends FunSuite with MLlibTestSparkContext {
sqlContext = new SQLContext(sc)
}

-test("RegexTokenizer"){
+test("RegexTokenizer") {
val myRegExTokenizer = new RegexTokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")

var dataset = sqlContext.createDataFrame(
-sc.parallelize(List(
-TextData("Test for tokenization.",List("Test","for","tokenization",".")),
-TextData("Te,st. punct",List("Te",",","st",".","punct"))
+sc.parallelize(Seq(
+TextData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
+TextData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
)))
-testRegexTokenizer(myRegExTokenizer,dataset)
+testRegexTokenizer(myRegExTokenizer, dataset)

dataset = sqlContext.createDataFrame(
-sc.parallelize(List(
-TextData("Test for tokenization.",List("Test","for","tokenization")),
-TextData("Te,st. punct",List("punct"))
+sc.parallelize(Seq(
+TextData("Test for tokenization.", Seq("Test", "for", "tokenization")),
+TextData("Te,st. punct", Seq("punct"))
)))
myRegExTokenizer.asInstanceOf[RegexTokenizer]
.setMinTokenLength(3)
-testRegexTokenizer(myRegExTokenizer,dataset)
+testRegexTokenizer(myRegExTokenizer, dataset)

myRegExTokenizer.asInstanceOf[RegexTokenizer]
.setPattern("\\s")
.setGaps(true)
.setMinTokenLength(0)
dataset = sqlContext.createDataFrame(
-sc.parallelize(List(
-TextData("Test for tokenization.",List("Test","for","tokenization.")),
-TextData("Te,st. punct",List("Te,st.","","punct"))
+sc.parallelize(Seq(
+TextData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
+TextData("Te,st. punct", Seq("Te,st.", "", "punct"))
)))
-testRegexTokenizer(myRegExTokenizer,dataset)
+testRegexTokenizer(myRegExTokenizer, dataset)
}

test("Tokenizer") {
val oldTokenizer = new Tokenizer()
.setInputCol("rawText")
.setOutputCol("tokens")
var dataset = sqlContext.createDataFrame(
-sc.parallelize(List(
-TextData("Test for tokenization.",List("test","for","tokenization.")),
-TextData("Te,st. punct",List("te,st.","","punct"))
+sc.parallelize(Seq(
+TextData("Test for tokenization.", Seq("test", "for", "tokenization.")),
+TextData("Te,st. punct", Seq("te,st.", "", "punct"))
)))
-testTokenizer(oldTokenizer,dataset)
+testTokenizer(oldTokenizer, dataset)
}

-def testTokenizer(t: Tokenizer,dataset: DataFrame): Unit = {
+def testTokenizer(t: Tokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
-.select("tokens","wantedTokens")
-.collect().foreach{
+.select("tokens", "wantedTokens")
+.collect().foreach {
case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) =>
assert(tokens === wantedTokens)
case e =>
throw new SparkException(s"Row $e should contain only tokens and wantedTokens columns")
}
}

-def testRegexTokenizer(t: RegexTokenizer,dataset: DataFrame): Unit = {
+def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
t.transform(dataset)
-.select("tokens","wantedTokens")
-.collect().foreach{
+.select("tokens", "wantedTokens")
+.collect().foreach {
case Row(tokens: Seq[Any], wantedTokens: Seq[Any]) =>
assert(tokens === wantedTokens)
case e =>
throw new SparkException(s"Row $e should contain only tokens and wantedTokens columns")
}
}
