Removed stopwords parameters and updated doc
Still need to add unit test.
Augustin Borsu committed Mar 3, 2015
1 parent 19f9e53 commit 9082fc3
Showing 1 changed file with 22 additions and 39 deletions.
61 changes: 22 additions & 39 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -43,57 +43,40 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {

 /**
  * :: AlphaComponent ::
- * A regex based tokenizer that extracts tokens using a regex.
- * Optional additional parameters include enabling lowercase standardization, a minimum character
- * size for tokens as well as an array of stop words to remove from the results.
+ * A regex based tokenizer that extracts tokens either by repeatedly matching the regex (default)
+ * or by using it to split the text (set matching to false). Optional parameters also allow folding
+ * the text to lowercase before it is tokenized and filtering out tokens below a minimum length.
+ * It returns an array of strings that can be empty.
+ * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
+ * lowercase = false, minTokenLength = 1.
  */
 @AlphaComponent
 class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

-  val lowerCase = new BooleanParam(this,
-    "lowerCase",
-    "enable case folding to lower case",
-    Some(true))
+  val lowerCase = new BooleanParam(this, "lowerCase", "Folds case to lower case", Some(false))

   def setLowercase(value: Boolean) = set(lowerCase, value)

   def getLowercase: Boolean = get(lowerCase)

-  val minLength = new IntParam(this,
-    "minLength",
-    "minimum token length (excluded)",
-    Some(0))
-  def setMinLength(value: Int) = set(minLength, value)
-  def getMinLength: Int = get(minLength)
+  val minTokenLength = new IntParam(this, "minLength", "minimum token length", Some(1))
+  def setMinTokenLength(value: Int) = set(minTokenLength, value)
+  def getMinTokenLength: Int = get(minTokenLength)

-  val regEx = new Param(this,
-    "regEx",
-    "RegEx used for tokenizing",
-    Some("\\p{L}+|[^\\p{L}\\s]+".r))
-  def setRegex(value: scala.util.matching.Regex) = set(regEx, value)
-  def getRegex: scala.util.matching.Regex = get(regEx)
-
-  val stopWords = new Param(this,
-    "stopWords",
-    "array of tokens to filter from results",
-    Some(Array[String]()))
-  def setStopWords(value: Array[String]) = set(stopWords, value)
-  def getStopWords: Array[String] = get(stopWords)
+  val matching = new BooleanParam(this, "matching", "Sets regex to matching or split", Some(true))
+  def setMatching(value: Boolean) = set(matching, value)
+  def getMatching: Boolean = get(matching)
+
+  val regex = new Param(this, "regex", "regex used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))
+  def setRegex(value: String) = set(regex, value)
+  def getRegex: String = get(regex)

   override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { x =>
-    var string = x
-    if (paramMap(lowerCase)) {
-      string = string.toLowerCase
-    }
-    var tokens = (paramMap(regEx) findAllIn string).toList
-
-    if (paramMap(minLength) > 0) {
-      tokens = tokens.filter(_.length > paramMap(minLength))
-    }
-    if (paramMap(stopWords).length > 0) {
-      tokens = tokens.filter(!paramMap(stopWords).contains(_))
-    }
-    tokens
+    val str = if (paramMap(lowerCase)) x.toLowerCase else x
+
+    val re = paramMap(regex)
+    val tokens = if (paramMap(matching)) re.r.findAllIn(str).toList else str.split(re).toList
+
+    tokens.filter(_.length >= paramMap(minTokenLength))
   }

   override protected def validateInputType(inputType: DataType): Unit = {
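For illustration (not part of the commit), here is a minimal, dependency-free Scala sketch of the tokenization logic the new createTransformFunc implements. The object and method names (RegexTokenizerSketch, tokenize) are hypothetical; the defaults mirror the ones documented in the class comment above.

// Hypothetical standalone sketch of the commit's tokenization logic;
// no Spark dependencies, defaults mirror the documented parameters.
object RegexTokenizerSketch {
  def tokenize(
      text: String,
      regex: String = "\\p{L}+|[^\\p{L}\\s]+",
      matching: Boolean = true,
      lowercase: Boolean = false,
      minTokenLength: Int = 1): Seq[String] = {
    // Optionally fold the text to lowercase before tokenizing.
    val str = if (lowercase) text.toLowerCase else text
    // matching = true: every regex match becomes a token;
    // matching = false: the regex marks the gaps between tokens.
    val tokens =
      if (matching) regex.r.findAllIn(str).toList
      else str.split(regex).toList
    // Drop tokens shorter than the configured minimum length.
    tokens.filter(_.length >= minTokenLength)
  }

  def main(args: Array[String]): Unit = {
    println(tokenize("Don't panic!"))
    // -> List(Don, ', t, panic, !)
    println(tokenize("one two  three", regex = "\\s+", matching = false))
    // -> List(one, two, three)
  }
}

Assuming the setters above chain (i.e. ml.param's set returns this), configuring the transformer itself to split on whitespace and keep only tokens of at least two characters would look roughly like: new RegexTokenizer().setRegex("\\s+").setMatching(false).setMinTokenLength(2).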
