Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

guardrails model support #2491

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.opensearch.ml.common;

import com.google.common.collect.ImmutableSet;
import org.opensearch.Version;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.controller.MLController;

import java.util.Set;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD;
Expand Down Expand Up @@ -71,6 +74,7 @@ public class CommonValue {
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,19 @@

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.client.Client;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
public abstract class Guardrail implements ToXContentObject {

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";
public abstract void writeTo(StreamOutput out) throws IOException;

private List<StopWords> stopWords;
private String[] regex;
public abstract Boolean validate(String input, Map<String, String> parameters);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why Boolean and not boolean?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to use object not primitive, as long as memory saving is not critical.


@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -24,6 +26,7 @@ public class Guardrails implements ToXContentObject {
public static final String TYPE_FIELD = "type";
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";
public static final Set<String> types = Set.of("local_regex", "model");

private String type;
private Guardrail inputGuardrail;
Expand All @@ -39,10 +42,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra
public Guardrails(StreamInput input) throws IOException {
type = input.readString();
if (input.readBoolean()) {
inputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
inputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
if (input.readBoolean()) {
outputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
outputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}

Expand Down Expand Up @@ -80,8 +99,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

public static Guardrails parse(XContentParser parser) throws IOException {
String type = null;
Guardrail inputGuardrail = null;
Guardrail outputGuardrail = null;
Map<String, Object> inputGuardrailMap = null;
Map<String, Object> outputGuardrailMap = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -93,20 +112,46 @@ public static Guardrails parse(XContentParser parser) throws IOException {
type = parser.text();
break;
case INPUT_GUARDRAIL_FIELD:
inputGuardrail = Guardrail.parse(parser);
inputGuardrailMap = parser.map();
break;
case OUTPUT_GUARDRAIL_FIELD:
outputGuardrail = Guardrail.parse(parser);
outputGuardrailMap = parser.map();
break;
default:
parser.skipChildren();
break;
}
}
if (!validateType(type)) {
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
}

return Guardrails.builder()
.type(type)
.inputGuardrail(inputGuardrail)
.outputGuardrail(outputGuardrail)
.inputGuardrail(createGuardrail(type, inputGuardrailMap))
.outputGuardrail(createGuardrail(type, outputGuardrailMap))
.build();
}

private static Boolean validateType(String type) {
if (types.contains(type)) {
return true;
}
return false;
}

private static Guardrail createGuardrail(String type, Map<String, Object> params) {
if (params == null || params.isEmpty()) {
return null;
}

switch (type) {
case "local_regex":
return new LocalRegexGuardrail(params);
case "model":
return new ModelGuardrail(params);
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}
Loading
Loading