Skip to content

Commit

Permalink
use local_regex as default type for guardrails (opensearch-project#2853)
Browse files Browse the repository at this point in the history
* use local_regex as default type for guardrails

Signed-off-by: Jing Zhang <[email protected]>

* add UT for model type

Signed-off-by: Jing Zhang <[email protected]>

---------

Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es authored Aug 28, 2024
1 parent 0a89537 commit 7ecff1a
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ public static Guardrails parse(XContentParser parser) throws IOException {
break;
}
}
if (type == null) {
type = "local_regex";
}
if (!validateType(type)) {
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.junit.Assert;
import org.junit.Before;
Expand All @@ -27,13 +28,17 @@ public class GuardrailsTests {
String[] regex;
LocalRegexGuardrail inputLocalRegexGuardrail;
LocalRegexGuardrail outputLocalRegexGuardrail;
ModelGuardrail inputModelGuardrail;
ModelGuardrail outputModelGuardrail;

@Before
public void setUp() {
stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0]));
regex = List.of("regex1").toArray(new String[0]);
inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex);
inputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept"));
outputModelGuardrail = new ModelGuardrail(Map.of("model_id", "guardrail_model_id", "response_validation_regex", "accept"));
}

@Test
Expand Down Expand Up @@ -83,4 +88,44 @@ public void parse() throws IOException {
Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail);
Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail);
}

@Test
public void parseNonType() throws IOException {
String jsonStr = "{"
+ "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]},"
+ "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
jsonStr
);
parser.nextToken();
Guardrails guardrails = Guardrails.parse(parser);

Assert.assertEquals(guardrails.getType(), "local_regex");
Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail);
Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail);
}

@Test
public void parseModelType() throws IOException {
String jsonStr = "{\"type\":\"model\","
+ "\"input_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"},"
+ "\"output_guardrail\":{\"model_id\":\"guardrail_model_id\",\"response_validation_regex\":\"accept\"}}";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
jsonStr
);
parser.nextToken();
Guardrails guardrails = Guardrails.parse(parser);

Assert.assertEquals(guardrails.getType(), "model");
Assert.assertEquals(guardrails.getInputGuardrail(), inputModelGuardrail);
Assert.assertEquals(guardrails.getOutputGuardrail(), outputModelGuardrail);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,31 @@ public void testPredictRemoteModelFailed() throws IOException, InterruptedExcept
predictRemoteModel(modelId, predictInput);
}

public void testPredictRemoteModelFailedNonType() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
exceptionRule.expect(ResponseException.class);
exceptionRule.expectMessage("guardrails triggered for user input");
Response response = createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModelNonTypeGuardrails("openAI-GPT-3.5 completions", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test of stop word.\"\n" + " }\n" + "}";
predictRemoteModel(modelId, predictInput);
}

public void testPredictRemoteModelSuccessWithModelGuardrail() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -437,6 +462,66 @@ protected Response registerRemoteModelWithLocalRegexGuardrails(String name, Stri
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

protected Response registerRemoteModelNonTypeGuardrails(String name, String connectorId) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"description\": \"This is an example description\"\n"
+ "}";
Response response = TestHelper
.makeRequest(
client(),
"POST",
"/_plugins/_ml/model_groups/_register",
null,
TestHelper.toHttpEntity(registerModelGroupEntity),
null
);
Map responseMap = parseResponseToMap(response);
assertEquals((String) responseMap.get("status"), "CREATED");
String modelGroupId = (String) responseMap.get("model_group_id");

String registerModelEntity = "{\n"
+ " \"name\": \""
+ name
+ "\",\n"
+ " \"function_name\": \"remote\",\n"
+ " \"model_group_id\": \""
+ modelGroupId
+ "\",\n"
+ " \"version\": \"1.0.0\",\n"
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"guardrails\": {\n"
+ " \"input_guardrail\": {\n"
+ " \"stop_words\": [\n"
+ " {"
+ " \"index_name\": \"stop_words\",\n"
+ " \"source_fields\": [\"title\"]\n"
+ " }"
+ " ],\n"
+ " \"regex\": [\"regex1\", \"regex2\"]\n"
+ " },\n"
+ " \"output_guardrail\": {\n"
+ " \"stop_words\": [\n"
+ " {"
+ " \"index_name\": \"stop_words\",\n"
+ " \"source_fields\": [\"title\"]\n"
+ " }"
+ " ],\n"
+ " \"regex\": [\"regex1\", \"regex2\"]\n"
+ " }\n"
+ "},\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

protected Response registerRemoteModelWithModelGuardrails(String name, String connectorId, String guardrailModelId) throws IOException {

String registerModelGroupEntity = "{\n"
Expand Down

0 comments on commit 7ecff1a

Please sign in to comment.