restructure the directories of zoo project #7

Merged
merged 5 commits into from May 22, 2017
Changes from 1 commit
add segmenter
JerryYanWan committed May 22, 2017

Verified: This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 4b1db6452b81799f066459de8f970d18eac77b69
@@ -7,7 +7,7 @@ import com.intel.analytics.deepspeech2.util.{LocalOptimizerPerfParam, parser}
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.feature.FlacReader
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}

 /**
  * load trained model to inference the audio files.
@@ -32,13 +32,13 @@ object InferenceEvaluate {
     val st = System.nanoTime()

     val df = dataLoader(spark, param.dataPath, param.numFile, param.partition)
     df.repartition(param.partition)
     logger.info(s"${df.count()} audio files, in ${df.rdd.partitions.length} partitions")
     df.show()

-    val pipeline = getPipeline(param.modelPath, uttLength, windowSize, windowStride, numFilters)
+    val pipeline = getPipeline(param.modelPath, uttLength, windowSize, windowStride, numFilters, sampleRate, param.segment)
     val model = pipeline.fit(df)
-    evaluate(model, df)
+    evaluate(model, df, param.segment)

     logger.info("total time = " + (System.nanoTime() - st) / 1e9)
   }
@@ -88,10 +88,14 @@ object InferenceEvaluate {
   }

   private def getPipeline(modelPath: String, uttLength: Int, windowSize: Int,
-      windowStride: Int, numFilter: Int): Pipeline = {
+      windowStride: Int, numFilter: Int, sampleRate: Int, isSegment: Boolean): Pipeline = {

-    val windower = new Windower()
+    val segmenter = new TimeSegmenter()
+      .setSegmentSize(sampleRate * 4)
+      .setInputCol("samples")
+      .setOutputCol("segments")
+    val windower = new Windower()
+      .setInputCol("segments")
       .setOutputCol("window")
       .setOriginalSizeCol("originalSizeCol")
       .setWindowShift(windowStride)
@@ -123,12 +127,46 @@ object InferenceEvaluate {
       .setUttLength(uttLength)
       .setWindowSize(windowSize)

-    new Pipeline().setStages(
-      Array(windower, dftSpecgram, melbank, transposeFlip, modelTransformer, decoder))
+    val pipeline = new Pipeline()
+
+    if (isSegment) {
+      pipeline.setStages(
+        Array(segmenter, windower, dftSpecgram, melbank, transposeFlip, modelTransformer, decoder))
+    } else {
+      windower
+        .setInputCol("samples")
+      pipeline.setStages(
+        Array(windower, dftSpecgram, melbank, transposeFlip, modelTransformer, decoder))
+    }
+
+    pipeline
   }

-  private def evaluate(model: PipelineModel, df: DataFrame): Unit = {
-    val result = model.transform(df).select("path", "output", "target").cache()
+  private def evaluate(model: PipelineModel, df: DataFrame, isSegment: Boolean): Unit = {
+
+    val result = if (isSegment) {
+      val results = model.transform(df).select("path", "target", "audio_id", "audio_seq", "output").cache()
+      results.select("path", "audio_id", "audio_seq", "output").show(false)
+
+      val grouped = results.rdd.map {
+        case Row(path: String, target: String, audio_id: Long, audio_seq: Int, output: String) =>
+          (audio_id, (path, target, audio_seq, output))
+      }.groupByKey()
+        .map(_._2)
+        .map { iter =>
+          val path = iter.head._1
+          val target = iter.head._2
+          val text = iter.toArray.sortBy(_._3).map(_._4).mkString(" ")
+          (path, text, target)
+        }
+
+      val spark = df.sparkSession
+      import spark.implicits._
+      spark.createDataset(grouped).toDF("path", "output", "target")
+    } else {
+      model.transform(df).select("path", "output", "target").cache()
+    }
+
     logger.info(s"evaluation result:")
     result.select("output", "target").rdd.map { r =>
       val output = r.getString(0)
@@ -146,5 +184,4 @@ object InferenceEvaluate {
       .setPredictionCol("output").setMetricName("wer").evaluate(result)
     logger.info("wer = " + wer)
   }
-
 }
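For reference, the regrouping that the new evaluate branch performs — collect the per-segment outputs of one audio file, order them by audio_seq, and join them back into a single transcript — can be sketched with plain Scala collections instead of an RDD. The object name and sample tuples below are made up for illustration, not data from this PR:

object RegroupSketch extends App {
  // (audio_id, (path, target, audio_seq, output)) pairs, mirroring what evaluate() builds
  val segments = Seq(
    (0L, ("a.flac", "hello world", 1, "world")),
    (0L, ("a.flac", "hello world", 0, "hello")),
    (1L, ("b.flac", "good morning", 0, "good morning")))

  val merged = segments
    .groupBy(_._1)   // one group per original audio file
    .values
    .map { group =>
      val (path, target, _, _) = group.head._2
      // order the segments by audio_seq, then join their decoded text
      val text = group.map(_._2).sortBy(_._3).map(_._4).mkString(" ")
      (path, text, target)
    }

  merged.foreach(println)
  // (a.flac,hello world,hello world)
  // (b.flac,good morning,good morning)   (group order may vary)
}

The groupByKey in the real code has the same shape but shuffles rows across Spark partitions; sorting by audio_seq inside each group is what restores the transcript's original word order.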
@@ -34,15 +34,17 @@ class TimeSegmenter ( override val uid: String)
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)

-    val rows = dataset.select("path", $(inputCol)).rdd.zipWithIndex().flatMap { case (r, id) =>
+    val rows = dataset.select("path", "target", $(inputCol)).rdd.zipWithIndex().flatMap { case (r, id) =>
       val path = r.getAs[String](0)
-      val arr = r.getSeq[Float](1).toArray.grouped($(segmentSize))
+      val target = r.getAs[String](1)
+      val arr = r.getSeq[Float](2).toArray.grouped($(segmentSize))
       arr.zipWithIndex.map { case (data, seq) =>
-        Row(path, id, seq, data)
+        Row(path, target, id, seq, data)
       }
     }
     val schema = StructType(Seq(
       StructField("path", StringType, nullable = false),
+      StructField("target", StringType, nullable = false),
       StructField("audio_id", LongType, nullable = false),
       StructField("audio_seq", IntegerType, nullable = false),
       StructField($(outputCol), dataset.schema($(inputCol)).dataType, nullable = false)
@@ -65,4 +67,3 @@ object TimeSegmenter extends DefaultParamsReadable[TimeSegmenter] {

   override def load(path: String): TimeSegmenter = super.load(path)
 }
-
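The segmentation itself rests on Scala's grouped, which TimeSegmenter calls on the raw sample array: fixed-size chunks plus a shorter trailing remainder, each tagged with its sequence number. A minimal sketch, with a made-up 10-sample array and a segment size of 4 standing in for the sampleRate * 4 used in getPipeline above:

object SegmentSketch extends App {
  val samples: Array[Float] = (1 to 10).map(_.toFloat).toArray
  val segmentSize = 4   // the pipeline uses sampleRate * 4, i.e. four seconds of audio

  // grouped() yields full chunks followed by a shorter final remainder
  samples.grouped(segmentSize).zipWithIndex.foreach { case (data, seq) =>
    println((seq, data.toVector))
  }
  // (0,Vector(1.0, 2.0, 3.0, 4.0))
  // (1,Vector(5.0, 6.0, 7.0, 8.0))
  // (2,Vector(9.0, 10.0))
}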
@@ -20,11 +20,15 @@ object parser {
     opt[Int]('n', "num")
       .text("file number. Default is 8")
       .action((v, p) => p.copy(numFile = v))
+    opt[Boolean]('s', "segment")
+      .text("whether to segment audio or not. Default is false")
+      .action((v, p) => p.copy(segment = v))
   }
 }

 case class LocalOptimizerPerfParam(
   dataPath: String = null,
   modelPath: String = null,
   partition: Int = 4,
-  numFile: Int = 8 )
+  numFile: Int = 8,
+  segment: Boolean = false)
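For the new flag, scopt's opt[Boolean] expects an explicit value, so the switch would be passed as --segment true (or -s true). A minimal sketch of the parsing behavior, assuming scopt 3.x on the classpath; the Config case class and program name here are stand-ins for this project's LocalOptimizerPerfParam and parser:

import scopt.OptionParser

case class Config(segment: Boolean = false)

object SegmentFlagSketch extends App {
  val cli = new OptionParser[Config]("InferenceEvaluate") {
    opt[Boolean]('s', "segment")
      .text("whether to segment audio or not. Default is false")
      .action((v, c) => c.copy(segment = v))
  }
  // parse returns Some(config) on success, None on bad input
  cli.parse(Seq("--segment", "true"), Config()).foreach { c =>
    println(c.segment)   // true
  }
}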