[ML][FEATURE] SPARK-5566: RegEx Tokenizer
Added a regex-based tokenizer for ml.
Currently the regex is fixed, but if I could add a regex-type parameter to the paramMap,
changing the tokenizer regex could be a parameter used in cross-validation.
Also, I wonder what would be the best way to add a stop word list.
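
For illustration only (not part of this commit): a minimal sketch of how the tokenizer's pattern Param could be tuned during cross-validation with the spark.ml tuning API of that era, assuming a toy pipeline built from HashingTF and LogisticRegression, a BinaryClassificationEvaluator, and a hypothetical training DataFrame named `training` with "rawText" and "label" columns.

// Hypothetical sketch, not part of this commit.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, RegexTokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val tokenizer = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")
val hashingTF = new HashingTF()
  .setInputCol("tokens")
  .setOutputCol("features")
val lr = new LogisticRegression()
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Try two tokenization regexes as part of the parameter grid.
val paramGrid = new ParamGridBuilder()
  .addGrid(tokenizer.pattern, Array("\\p{L}+|[^\\p{L}\\s]+", "\\w+"))
  .build()

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// `training` is an assumed DataFrame with "rawText" and "label" columns.
val cvModel = cv.fit(training)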

Author: Augustin Borsu <[email protected]>
Author: Augustin Borsu <[email protected]>
Author: Augustin Borsu <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes #4504 from aborsu985/master and squashes the following commits:

716d257 [Augustin Borsu] Merge branch 'mengxr-SPARK-5566'
cb07021 [Augustin Borsu] Merge branch 'SPARK-5566' of git://github.com/mengxr/spark into mengxr-SPARK-5566
5f09434 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
a164800 [Xiangrui Meng] remove tabs
556aa27 [Xiangrui Meng] Merge branch 'aborsu985-master' into SPARK-5566
9651aec [Xiangrui Meng] update test
f96526d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5566
2338da5 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
e88d7b8 [Xiangrui Meng] change pattern to a StringParameter; update tests
148126f [Augustin Borsu] Added return type to public functions
12dddb4 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
daf685e [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
6a85982 [Augustin Borsu] Style corrections
38b95a1 [Augustin Borsu] Added Java unit test for RegexTokenizer
b66313f [Augustin Borsu] Modified the pattern Param so it is compiled when given to the Tokenizer
e262bac [Augustin Borsu] Added unit tests in scala
cd6642e [Augustin Borsu] Changed regex to pattern
132b00b [Augustin Borsu] Changed matching to gaps and removed case folding
201a107 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
cb9c9a7 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
d3ef6d3 [Augustin Borsu] Added doc to RegexTokenizer
9082fc3 [Augustin Borsu] Removed stopwords parameters and updated doc
19f9e53 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
f6a5002 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
7f930bb [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
77ff9ca [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
2e89719 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
196cd7a [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
11ca50f [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
9f8685a [Augustin Borsu] RegexTokenizer
9e07a78 [Augustin Borsu] Merge remote-tracking branch 'upstream/master'
9547e9d [Augustin Borsu] RegEx Tokenizer
01cd26f [Augustin Borsu] RegExTokenizer
Augustin Borsu authored and mengxr committed Mar 25, 2015
1 parent 10c7860 commit 982952f
Showing 3 changed files with 221 additions and 1 deletion.
66 changes: 65 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
import org.apache.spark.sql.types.{DataType, StringType, ArrayType}

/**
@@ -39,3 +39,67 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}

/**
 * :: AlphaComponent ::
 * A regex-based tokenizer that extracts tokens either by repeatedly matching the regex (default)
 * or by using it to split the text (set gaps to true). An optional parameter also allows
 * filtering out tokens shorter than a minimum length.
 * It returns an array of strings that can be empty.
 * The default parameters are pattern = "\\p{L}+|[^\\p{L}\\s]+", gaps = false,
 * minTokenLength = 1.
 */
@AlphaComponent
class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

  /**
   * param for minimum token length, default is one to avoid returning empty strings
   * @group param
   */
  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1))

  /** @group setParam */
  def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)

  /** @group getParam */
  def getMinTokenLength: Int = get(minTokenLength)

  /**
   * param sets regex as splitting on gaps (true) or matching tokens (false)
   * @group param
   */
  val gaps: BooleanParam = new BooleanParam(
    this, "gaps", "Set regex to match gaps or tokens", Some(false))

  /** @group setParam */
  def setGaps(value: Boolean): this.type = set(gaps, value)

  /** @group getParam */
  def getGaps: Boolean = get(gaps)

  /**
   * param sets regex pattern used by tokenizer
   * @group param
   */
  val pattern: Param[String] = new Param(
    this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+"))

  /** @group setParam */
  def setPattern(value: String): this.type = set(pattern, value)

  /** @group getParam */
  def getPattern: String = get(pattern)

  override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str =>
    val re = paramMap(pattern).r
    val tokens = if (paramMap(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq
    val minLength = paramMap(minTokenLength)
    tokens.filter(_.length >= minLength)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType == StringType, s"Input type must be string type but got $inputType.")
  }

  override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
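
As a reading aid (not part of the diff), a minimal usage sketch of the new RegexTokenizer, assuming an existing SQLContext named `sqlContext`; the expected token sequences mirror the test cases added below.

// Usage sketch, not part of the diff.
import org.apache.spark.ml.feature.RegexTokenizer

val df = sqlContext.createDataFrame(Seq(
  (0, "Test for tokenization.")
)).toDF("id", "rawText")

// Default behavior: repeatedly match the pattern "\\p{L}+|[^\\p{L}\\s]+"
// -> Seq("Test", "for", "tokenization", ".")
val matching = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")
matching.transform(df).select("tokens").show()

// With gaps = true the pattern is used to split, here on whitespace
// -> Seq("Test", "for", "tokenization.")
val splitting = new RegexTokenizer()
  .setInputCol("rawText")
  .setOutputCol("tokens")
  .setPattern("\\s")
  .setGaps(true)
splitting.transform(df).select("tokens").show()
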
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature;

import com.google.common.collect.Lists;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class JavaTokenizerSuite {
  private transient JavaSparkContext jsc;
  private transient SQLContext jsql;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaTokenizerSuite");
    jsql = new SQLContext(jsc);
  }

  @After
  public void tearDown() {
    jsc.stop();
    jsc = null;
  }

  @Test
  public void regexTokenizer() {
    RegexTokenizer myRegExTokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")
      .setPattern("\\s")
      .setGaps(true)
      .setMinTokenLength(3);

    JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Lists.newArrayList(
      new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
      new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
    ));
    DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);

    Row[] pairs = myRegExTokenizer.transform(dataset)
      .select("tokens", "wantedTokens")
      .collect();

    for (Row r : pairs) {
      Assert.assertEquals(r.get(0), r.get(1));
    }
  }
}
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import scala.beans.BeanInfo

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Seq[String]) {
  /** Constructor used in [[org.apache.spark.ml.feature.JavaTokenizerSuite]] */
  def this(rawText: String, wantedTokens: Array[String]) = this(rawText, wantedTokens.toSeq)
}

class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext {
  import org.apache.spark.ml.feature.RegexTokenizerSuite._

  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("RegexTokenizer") {
    val tokenizer = new RegexTokenizer()
      .setInputCol("rawText")
      .setOutputCol("tokens")

    val dataset0 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization", ".")),
      TokenizerTestData("Te,st. punct", Seq("Te", ",", "st", ".", "punct"))
    ))
    testRegexTokenizer(tokenizer, dataset0)

    val dataset1 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization")),
      TokenizerTestData("Te,st. punct", Seq("punct"))
    ))

    tokenizer.setMinTokenLength(3)
    testRegexTokenizer(tokenizer, dataset1)

    tokenizer
      .setPattern("\\s")
      .setGaps(true)
      .setMinTokenLength(0)
    val dataset2 = sqlContext.createDataFrame(Seq(
      TokenizerTestData("Test for tokenization.", Seq("Test", "for", "tokenization.")),
      TokenizerTestData("Te,st. punct", Seq("Te,st.", "", "punct"))
    ))
    testRegexTokenizer(tokenizer, dataset2)
  }
}

object RegexTokenizerSuite extends FunSuite {

  def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
    t.transform(dataset)
      .select("tokens", "wantedTokens")
      .collect()
      .foreach {
        case Row(tokens, wantedTokens) =>
          assert(tokens === wantedTokens)
      }
  }
}
