Skip to content

Commit

Permalink
Modified the pattern Param so it is compiled when given to the Tokenizer
Browse files Browse the repository at this point in the history
The pattern is set and retrieved as a string but stored as a compiled regex; this
prevents having to recompile it every time the transform function is called.
  • Loading branch information
Augustin Borsu committed Mar 19, 2015
1 parent e262bac commit b66313f
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
* lowercase = false, minTokenLength = 1
*/
@AlphaComponent
class RegexTokenizer extends Tokenizer {
class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

/**
* param for minimum token length, default is one to avoid returning empty strings
Expand All @@ -65,7 +65,7 @@ class RegexTokenizer extends Tokenizer {
def getMinTokenLength: Int = get(minTokenLength)

/**
* param sets regex as splitting on gaps(true) or matching tokens (false)
* param sets regex as splitting on gaps (true) or matching tokens (false)
* @group param
*/
val gaps = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens", Some(false))
Expand All @@ -81,20 +81,20 @@ class RegexTokenizer extends Tokenizer {
* @group param
*/
val pattern = new Param(this, "pattern",
"regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
"regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+".r))

/** @group setParam */
def setPattern(value: String) = set(pattern, value)
def setPattern(value: String) = set(pattern, value.r)

/** @group getParam */
def getPattern: String = get(pattern)
def getPattern: String = get(pattern).toString

override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>

val re = paramMap(pattern)
val tokens = if(paramMap(gaps)) str.split(re).toList else (re.r.findAllIn(str)).toList
val tokens = if(paramMap(gaps)) re.split(str).toList else (re.findAllIn(str)).toList

tokens.filter(_.length >= paramMap(minTokenLength)).toSeq
tokens.filter(_.length >= paramMap(minTokenLength))
}

override protected def validateInputType(inputType: DataType): Unit = {
Expand Down

0 comments on commit b66313f

Please sign in to comment.