-
Notifications
You must be signed in to change notification settings - Fork 138
Commit
Signed-off-by: Bhavana Ramaram <[email protected]>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.common.dataset; | ||
|
||
import lombok.AccessLevel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.FieldDefaults; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.ml.common.annotation.InputDataSet; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
@Getter | ||
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) | ||
@InputDataSet(MLInputDataType.QUESTION_ANSWERING) | ||
public class QuestionAnsweringInputDataSet extends MLInputDataset { | ||
|
||
String question; | ||
|
||
String context; | ||
|
||
@Builder(toBuilder = true) | ||
public QuestionAnsweringInputDataSet(String question, String context) { | ||
super(MLInputDataType.QUESTION_ANSWERING); | ||
if(question == null) { | ||
throw new IllegalArgumentException("Question is not provided"); | ||
} | ||
if(context == null) { | ||
throw new IllegalArgumentException("Context is not provided"); | ||
} | ||
this.question = question; | ||
this.context = context; | ||
} | ||
|
||
public QuestionAnsweringInputDataSet(StreamInput in) throws IOException { | ||
super(MLInputDataType.TEXT_SIMILARITY); | ||
this.question = in.readString(); | ||
this.context = in.readString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
out.writeString(question); | ||
out.writeString(context); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.common.input.nlp; | ||
|
||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.xcontent.XContentBuilder; | ||
import org.opensearch.core.xcontent.XContentParser; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.dataset.MLInputDataset; | ||
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; | ||
|
||
|
||
/** | ||
* MLInput which supports a text similarity algorithm | ||
* Inputs are a query and a list of texts. Outputs are real numbers | ||
* Use this for Cross Encoder models | ||
*/ | ||
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.QUESTION_ANSWERING}) | ||
public class QuestionAnsweringMLInput extends MLInput { | ||
|
||
public QuestionAnsweringMLInput(FunctionName algorithm, MLInputDataset dataset) { | ||
super(algorithm, null, dataset); | ||
} | ||
|
||
public QuestionAnsweringMLInput(StreamInput in) throws IOException { | ||
super(in); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
super.writeTo(out); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(ALGORITHM_FIELD, algorithm.name()); | ||
if(parameters != null) { | ||
builder.field(ML_PARAMETERS_FIELD, parameters); | ||
Check warning on line 60 in common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java#L60
|
||
} | ||
if(inputDataset != null) { | ||
QuestionAnsweringInputDataSet ds = (QuestionAnsweringInputDataSet) this.inputDataset; | ||
String question = ds.getQuestion(); | ||
String context = ds.getContext(); | ||
builder.field(QUESTION_FIELD, question); | ||
builder.field(CONTEXT_FIELD, context); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public QuestionAnsweringMLInput(XContentParser parser, FunctionName functionName) throws IOException { | ||
super(); | ||
this.algorithm = functionName; | ||
String question = null; | ||
String context = null; | ||
|
||
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); | ||
while (parser.nextToken() != XContentParser.Token.END_OBJECT) { | ||
String fieldName = parser.currentName(); | ||
parser.nextToken(); | ||
|
||
switch (fieldName) { | ||
case QUESTION_FIELD: | ||
question = parser.text(); | ||
case CONTEXT_FIELD: | ||
context = parser.text(); | ||
default: | ||
parser.skipChildren(); | ||
break; | ||
} | ||
} | ||
if(question == null) { | ||
throw new IllegalArgumentException("Question is not provided"); | ||
Check warning on line 95 in common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java#L95
|
||
} | ||
if(context == null) { | ||
throw new IllegalArgumentException("Context is not provided"); | ||
Check warning on line 98 in common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java Codecov / codecov/patchcommon/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnsweringMLInput.java#L98
|
||
} | ||
inputDataset = new QuestionAnsweringInputDataSet(question, context); | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.common.dataset; | ||
|
||
import org.junit.Test; | ||
import org.opensearch.common.io.stream.BytesStreamOutput; | ||
import org.opensearch.core.common.bytes.BytesReference; | ||
import org.opensearch.core.common.io.stream.BytesStreamInput; | ||
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
import static org.junit.Assert.assertThrows; | ||
|
||
public class QuestionAnsweringInputDatasetTest { | ||
|
||
@Test | ||
public void testStreaming() throws IOException { | ||
String question = "What color is apple"; | ||
String context = "I like Apples. They are red"; | ||
QuestionAnsweringInputDataSet dataset = QuestionAnsweringInputDataSet.builder().question(question).context(context).build(); | ||
BytesStreamOutput outbytes = new BytesStreamOutput(); | ||
StreamOutput osso = new OutputStreamStreamOutput(outbytes); | ||
dataset.writeTo(osso); | ||
StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); | ||
QuestionAnsweringInputDataSet newDs = (QuestionAnsweringInputDataSet) MLInputDataset.fromStream(in); | ||
assert (question.equals("What color is apple")); | ||
assert (context.equals("I like Apples. They are red")); | ||
} | ||
|
||
@Test | ||
public void noContext_ThenFail() { | ||
String question = "What color is apple"; | ||
IllegalArgumentException e = assertThrows(IllegalArgumentException.class, | ||
() -> QuestionAnsweringInputDataSet.builder().question(question).build()); | ||
assert (e.getMessage().equals("Context is not provided")); | ||
} | ||
|
||
@Test | ||
public void noQuestion_ThenFail() { | ||
String context = "I like Apples. They are red"; | ||
assertThrows(IllegalArgumentException.class, | ||
() -> QuestionAnsweringInputDataSet.builder().context(context).build()); | ||
} | ||
} |