Skip to content

Commit

Permalink
guardrails model support
Browse files Browse the repository at this point in the history
Signed-off-by: Jing Zhang <[email protected]>
  • Loading branch information
jngz-es committed Jun 2, 2024
1 parent 9b072c4 commit 195d9a9
Show file tree
Hide file tree
Showing 10 changed files with 887 additions and 537 deletions.
100 changes: 6 additions & 94 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
@@ -1,105 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.client.Client;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";

private List<StopWords> stopWords;
private String[] regex;

@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}
public abstract class Guardrail implements ToXContentObject {

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;
public abstract void writeTo(StreamOutput out) throws IOException;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
public abstract Boolean validate(String input);

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -39,10 +41,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra
public Guardrails(StreamInput input) throws IOException {
type = input.readString();
if (input.readBoolean()) {
inputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
inputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
if (input.readBoolean()) {
outputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
outputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}

Expand Down Expand Up @@ -80,8 +98,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

public static Guardrails parse(XContentParser parser) throws IOException {
String type = null;
Guardrail inputGuardrail = null;
Guardrail outputGuardrail = null;
Map<String, Object> inputGuardrailMap = null;
Map<String, Object> outputGuardrailMap = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -93,20 +111,47 @@ public static Guardrails parse(XContentParser parser) throws IOException {
type = parser.text();
break;
case INPUT_GUARDRAIL_FIELD:
inputGuardrail = Guardrail.parse(parser);
inputGuardrailMap = parser.map();
break;
case OUTPUT_GUARDRAIL_FIELD:
outputGuardrail = Guardrail.parse(parser);
outputGuardrailMap = parser.map();
break;
default:
parser.skipChildren();
break;
}
}
if (!validateType(type)) {
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
}

return Guardrails.builder()
.type(type)
.inputGuardrail(inputGuardrail)
.outputGuardrail(outputGuardrail)
.inputGuardrail(createGuardrail(type, inputGuardrailMap))
.outputGuardrail(createGuardrail(type, outputGuardrailMap))
.build();
}

private static Boolean validateType(String type) {
Set<String> types = Set.of("local_regex", "model");
if (types.contains(type)) {
return true;
}
return false;
}

private static Guardrail createGuardrail(String type, Map<String, Object> params) {
if (params == null || params.isEmpty()) {
return null;
}

switch (type) {
case "local_regex":
return new LocalRegexGuardrail(params);
case "model":
return new ModelGuardrail(params);
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}
Loading

0 comments on commit 195d9a9

Please sign in to comment.