Skip to content

Commit

Permalink
[SPARKNLP-1105] Adding test tags
Browse files Browse the repository at this point in the history
  • Loading branch information
danilojsl committed Dec 27, 2024
1 parent 1ea76c3 commit 9e67d89
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def setUp(self):
self.pipeline_model = pipeline.fit(empty_df)


# @pytest.mark.slow
@pytest.mark.slow
class AlbertForMultipleChoiceTest(AlbertForMultipleChoiceTestSetup, unittest.TestCase):

def setUp(self):
Expand All @@ -61,7 +61,7 @@ def test_run(self):
self.assertTrue(row["answer"][0].result != "")


# @pytest.mark.slow
@pytest.mark.slow
class LightAlbertForMultipleChoiceTest(AlbertForMultipleChoiceTestSetup, unittest.TestCase):

def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,50 +1,61 @@
package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.nlp.MultiDocumentAssembler
import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations, MultiDocumentAssembler}
import com.johnsnowlabs.nlp.annotators.SparkSessionTest
import com.johnsnowlabs.nlp.base.LightPipeline
import com.johnsnowlabs.tags.SlowTest
import org.apache.spark.ml.Pipeline
import org.scalatest.flatspec.AnyFlatSpec

class AlbertForMultipleChoiceTest extends AnyFlatSpec with SparkSessionTest {

import spark.implicits._
val onnxModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/onnx"
val sparkNLPModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/spark-nlp"
val openVinoModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/openvino"

lazy val pipelineModel = getAlbertForMultipleChoicePipelineModel

val testDataframe =
Seq(("The Eiffel Tower is located in which country?", "Germany, France, Italy"))
.toDF("question", "context")

"AlbertForMultipleChoice" should "loadSavedModel ONNX model" in {
val albertForMultipleChoice = AlbertForMultipleChoice.loadSavedModel(s"$onnxModelPath/albert_multiple_choice", spark)
albertForMultipleChoice.write.overwrite.save(s"$sparkNLPModelPath/onnx/albert_multiple_choice_onnx")
}
"AlbertForMultipleChoice" should "answer a multiple choice question" taggedAs SlowTest in {
val resultDf = pipelineModel.transform(testDataframe)
resultDf.show(truncate = false)

it should "loadSavedModel OpenVINO model" in {
val albertForMultipleChoice = AlbertForMultipleChoice.loadSavedModel(s"$openVinoModelPath/albert_multiple_choice_openvino", spark)
albertForMultipleChoice.write.overwrite.save(s"$sparkNLPModelPath/openvino/albert_multiple_choice_openvino")
val result = AssertAnnotations.getActualResult(resultDf, "answer")
result.foreach { annotation =>
annotation.foreach(a => assert(a.result.nonEmpty))
}
}

it should "work for ONNX" in {
val pipelineModel = getAlbertForMultipleChoicePipelineModel(s"$sparkNLPModelPath/onnx/albert_multiple_choice_onnx")
val resultDf = pipelineModel.transform(testDataframe)
resultDf.show(truncate = false)
it should "work with light pipeline fullAnnotate" taggedAs SlowTest in {
val lightPipeline = new LightPipeline(pipelineModel)
val resultFullAnnotate = lightPipeline.fullAnnotate(
"The Eiffel Tower is located in which country?",
"Germany, France, Italy")
println(s"resultAnnotate: $resultFullAnnotate")

val answerAnnotation = resultFullAnnotate("answer").head.asInstanceOf[Annotation]

assert(answerAnnotation.result.nonEmpty)
}

it should "work for OpenVINO" in {
val pipelineModel = getAlbertForMultipleChoicePipelineModel(s"$sparkNLPModelPath/openvino/albert_multiple_choice_openvino")
val resultDf = pipelineModel.transform(testDataframe)
resultDf.show(truncate = false)
it should "work with light pipeline annotate" taggedAs SlowTest in {
val lightPipeline = new LightPipeline(pipelineModel)
val resultAnnotate = lightPipeline.annotate(
"The Eiffel Tower is located in which country?",
"Germany, France, Italy")
println(s"resultAnnotate: $resultAnnotate")

assert(resultAnnotate("answer").head.nonEmpty)
}

private def getAlbertForMultipleChoicePipelineModel(modelPath: String) = {
private def getAlbertForMultipleChoicePipelineModel = {
val documentAssembler = new MultiDocumentAssembler()
.setInputCols("question", "context")
.setOutputCols("document_question", "document_context")

val bertForMultipleChoice = AlbertForMultipleChoice
.load(modelPath)
.pretrained()
.setInputCols("document_question", "document_context")
.setOutputCol("answer")

Expand Down

0 comments on commit 9e67d89

Please sign in to comment.