diff --git a/python/test/annotator/classifier_dl/albert_for_multiple_choice_test.py b/python/test/annotator/classifier_dl/albert_for_multiple_choice_test.py index 35a92b13908fb5..6e42465e8a2cea 100644 --- a/python/test/annotator/classifier_dl/albert_for_multiple_choice_test.py +++ b/python/test/annotator/classifier_dl/albert_for_multiple_choice_test.py @@ -46,7 +46,7 @@ def setUp(self): self.pipeline_model = pipeline.fit(empty_df) -# @pytest.mark.slow +@pytest.mark.slow class AlbertForMultipleChoiceTest(AlbertForMultipleChoiceTestSetup, unittest.TestCase): def setUp(self): @@ -61,7 +61,7 @@ def test_run(self): self.assertTrue(row["answer"][0].result != "") -# @pytest.mark.slow +@pytest.mark.slow class LightAlbertForMultipleChoiceTest(AlbertForMultipleChoiceTestSetup, unittest.TestCase): def setUp(self): diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForMultipleChoiceTest.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForMultipleChoiceTest.scala index 1522e4d22b632d..1385f41d94f502 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForMultipleChoiceTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForMultipleChoiceTest.scala @@ -1,50 +1,61 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl -import com.johnsnowlabs.nlp.MultiDocumentAssembler +import com.johnsnowlabs.nlp.{Annotation, AssertAnnotations, MultiDocumentAssembler} import com.johnsnowlabs.nlp.annotators.SparkSessionTest +import com.johnsnowlabs.nlp.base.LightPipeline +import com.johnsnowlabs.tags.SlowTest import org.apache.spark.ml.Pipeline import org.scalatest.flatspec.AnyFlatSpec class AlbertForMultipleChoiceTest extends AnyFlatSpec with SparkSessionTest { import spark.implicits._ - val onnxModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/onnx" - val sparkNLPModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/spark-nlp" - val openVinoModelPath = "/media/danilo/Data/Danilo/JSL/models/transformers/openvino" + + lazy val pipelineModel = getAlbertForMultipleChoicePipelineModel val testDataframe = Seq(("The Eiffel Tower is located in which country?", "Germany, France, Italy")) .toDF("question", "context") - "AlbertForMultipleChoice" should "loadSavedModel ONNX model" in { - val albertForMultipleChoice = AlbertForMultipleChoice.loadSavedModel(s"$onnxModelPath/albert_multiple_choice", spark) - albertForMultipleChoice.write.overwrite.save(s"$sparkNLPModelPath/onnx/albert_multiple_choice_onnx") - } + "AlbertForMultipleChoice" should "answer a multiple choice question" taggedAs SlowTest in { + val resultDf = pipelineModel.transform(testDataframe) + resultDf.show(truncate = false) - it should "loadSavedModel OpenVINO model" in { - val albertForMultipleChoice = AlbertForMultipleChoice.loadSavedModel(s"$openVinoModelPath/albert_multiple_choice_openvino", spark) - albertForMultipleChoice.write.overwrite.save(s"$sparkNLPModelPath/openvino/albert_multiple_choice_openvino") + val result = AssertAnnotations.getActualResult(resultDf, "answer") + result.foreach { annotation => + annotation.foreach(a => assert(a.result.nonEmpty)) + } } - it should "work for ONNX" in { - val pipelineModel = getAlbertForMultipleChoicePipelineModel(s"$sparkNLPModelPath/onnx/albert_multiple_choice_onnx") - val resultDf = pipelineModel.transform(testDataframe) - resultDf.show(truncate = false) + it should "work with light pipeline fullAnnotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultFullAnnotate = lightPipeline.fullAnnotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultFullAnnotate") + + val answerAnnotation = resultFullAnnotate("answer").head.asInstanceOf[Annotation] + + assert(answerAnnotation.result.nonEmpty) } - it should "work for OpenVINO" in { - val pipelineModel = getAlbertForMultipleChoicePipelineModel(s"$sparkNLPModelPath/openvino/albert_multiple_choice_openvino") - val resultDf = pipelineModel.transform(testDataframe) - resultDf.show(truncate = false) + it should "work with light pipeline annotate" taggedAs SlowTest in { + val lightPipeline = new LightPipeline(pipelineModel) + val resultAnnotate = lightPipeline.annotate( + "The Eiffel Tower is located in which country?", + "Germany, France, Italy") + println(s"resultAnnotate: $resultAnnotate") + + assert(resultAnnotate("answer").head.nonEmpty) } - private def getAlbertForMultipleChoicePipelineModel(modelPath: String) = { + private def getAlbertForMultipleChoicePipelineModel = { val documentAssembler = new MultiDocumentAssembler() .setInputCols("question", "context") .setOutputCols("document_question", "document_context") val bertForMultipleChoice = AlbertForMultipleChoice - .load(modelPath) + .pretrained() .setInputCols("document_question", "document_context") .setOutputCol("answer")