Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

guardrails model support #2491

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
guardrails model support
Signed-off-by: Jing Zhang <jngz@amazon.com>
  • Loading branch information
jngz-es committed Jun 6, 2024
commit 7511359ee510554c5eed1fc31a626a8970a14184
101 changes: 7 additions & 94 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
@@ -1,105 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.client.Client;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";

private List<StopWords> stopWords;
private String[] regex;

@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}
import java.util.Map;

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}
public abstract class Guardrail implements ToXContentObject {

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;
public abstract void writeTo(StreamOutput out) throws IOException;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();
public abstract Boolean validate(String input, Map<String, String> parameters);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why Boolean and not boolean?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to use object not primitive, as long as memory saving is not critical.

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
}
Original file line number Diff line number Diff line change
@@ -15,6 +15,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@@ -39,10 +41,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra
/**
 * Builds a Guardrails instance from a stream.
 * Reads the type string first, then, for the input guardrail and the output guardrail in turn,
 * reads a boolean presence flag and, when it is true, dispatches on the type.
 *
 * @param input stream to read the type, the presence flags and the guardrail payloads from
 * @throws IOException declared for the stream reads
 * @throws IllegalArgumentException if a presence flag is true and the type is neither
 *         "local_regex" nor "model"
 */
public Guardrails(StreamInput input) throws IOException {
type = input.readString();
if (input.readBoolean()) {
// NOTE(review): this assignment looks like the removed side of the rendered diff; the switch
// below appears to supersede it - confirm against the actual file.
inputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
inputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
// NOTE(review): this case reads nothing from the stream and assigns no guardrail even though
// the presence flag was true. Confirm the writer emits no payload for "model"; otherwise the
// following reads would be misaligned.
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
if (input.readBoolean()) {
// NOTE(review): same as above - likely the removed side of the diff; confirm.
outputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
outputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
// NOTE(review): same concern as the input side - nothing is read or assigned for "model".
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}

@@ -80,8 +98,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

public static Guardrails parse(XContentParser parser) throws IOException {
String type = null;
Guardrail inputGuardrail = null;
Guardrail outputGuardrail = null;
Map<String, Object> inputGuardrailMap = null;
Map<String, Object> outputGuardrailMap = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -93,20 +111,47 @@ public static Guardrails parse(XContentParser parser) throws IOException {
type = parser.text();
break;
case INPUT_GUARDRAIL_FIELD:
inputGuardrail = Guardrail.parse(parser);
inputGuardrailMap = parser.map();
break;
case OUTPUT_GUARDRAIL_FIELD:
outputGuardrail = Guardrail.parse(parser);
outputGuardrailMap = parser.map();
break;
default:
parser.skipChildren();
break;
}
}
if (!validateType(type)) {
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
}

return Guardrails.builder()
.type(type)
.inputGuardrail(inputGuardrail)
.outputGuardrail(outputGuardrail)
.inputGuardrail(createGuardrail(type, inputGuardrailMap))
.outputGuardrail(createGuardrail(type, outputGuardrailMap))
.build();
}

private static Boolean validateType(String type) {
Set<String> types = Set.of("local_regex", "model");
jngz-es marked this conversation as resolved.
Show resolved Hide resolved
if (types.contains(type)) {
return true;
}
return false;
}

/**
 * Instantiates the guardrail implementation that corresponds to {@code type}.
 *
 * @param type   guardrails type; "local_regex" and "model" are the recognised values
 * @param params parsed guardrail configuration; may be {@code null} or empty
 * @return {@code null} when {@code params} is {@code null} or empty, otherwise a new
 *         guardrail built from {@code params}
 * @throws IllegalArgumentException if {@code params} is non-empty and {@code type} is not
 *         one of the recognised values
 */
private static Guardrail createGuardrail(String type, Map<String, Object> params) {
final boolean nothingConfigured = params == null || params.isEmpty();
if (nothingConfigured) {
return null;
}

final Guardrail created;
switch (type) {
case "model":
created = new ModelGuardrail(params);
break;
case "local_regex":
created = new LocalRegexGuardrail(params);
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
return created;
}
}
Loading
Loading