Skip to content

Commit

Permalink
guardrails model support (#2491) (#2526)
Browse files Browse the repository at this point in the history
* guardrails model support

* add IT for remote guardrails model

* address comments

* address more comments 1

Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit 40c5edc)
  • Loading branch information
jngz-es authored and github-actions[bot] committed Oct 1, 2024
1 parent 8d9e41f commit c58168b
Show file tree
Hide file tree
Showing 15 changed files with 1,259 additions and 539 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.opensearch.ml.common;

import com.google.common.collect.ImmutableSet;
import org.opensearch.Version;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.connector.AbstractConnector;
import org.opensearch.ml.common.controller.MLController;

import java.util.Set;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD;
Expand Down Expand Up @@ -73,6 +76,7 @@ public class CommonValue {
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
Expand Down
96 changes: 7 additions & 89 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,101 +5,19 @@

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.client.Client;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
public abstract class Guardrail implements ToXContentObject {

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";
public abstract void writeTo(StreamOutput out) throws IOException;

private List<StopWords> stopWords;
private String[] regex;
public abstract Boolean validate(String input, Map<String, String> parameters);

@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
this.stopWords = stopWords;
this.regex = regex;
}

public Guardrail(StreamInput input) throws IOException {
if (input.readBoolean()) {
stopWords = new ArrayList<>();
int size = input.readInt();
for (int i=0; i<size; i++) {
stopWords.add(new StopWords(input));
}
}
regex = input.readStringArray();
}

public void writeTo(StreamOutput out) throws IOException {
if (stopWords != null && stopWords.size() > 0) {
out.writeBoolean(true);
out.writeInt(stopWords.size());
for (StopWords e : stopWords) {
e.writeTo(out);
}
} else {
out.writeBoolean(false);
}
out.writeStringArray(regex);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (stopWords != null && stopWords.size() > 0) {
builder.field(STOP_WORDS_FIELD, stopWords);
}
if (regex != null) {
builder.field(REGEX_FIELD, regex);
}
builder.endObject();
return builder;
}

public static Guardrail parse(XContentParser parser) throws IOException {
List<StopWords> stopWords = null;
String[] regex = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case STOP_WORDS_FIELD:
stopWords = new ArrayList<>();
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
stopWords.add(StopWords.parse(parser));
}
break;
case REGEX_FIELD:
regex = parser.list().toArray(new String[0]);
break;
default:
parser.skipChildren();
break;
}
}
return Guardrail.builder()
.stopWords(stopWords)
.regex(regex)
.build();
}
public abstract void init(NamedXContentRegistry xContentRegistry, Client client);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

Expand All @@ -24,6 +26,7 @@ public class Guardrails implements ToXContentObject {
public static final String TYPE_FIELD = "type";
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";
public static final Set<String> types = Set.of("local_regex", "model");

private String type;
private Guardrail inputGuardrail;
Expand All @@ -39,10 +42,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra
public Guardrails(StreamInput input) throws IOException {
type = input.readString();
if (input.readBoolean()) {
inputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
inputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
if (input.readBoolean()) {
outputGuardrail = new Guardrail(input);
switch (type) {
case "local_regex":
outputGuardrail = new LocalRegexGuardrail(input);
break;
case "model":
break;
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}

Expand Down Expand Up @@ -80,8 +99,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

public static Guardrails parse(XContentParser parser) throws IOException {
String type = null;
Guardrail inputGuardrail = null;
Guardrail outputGuardrail = null;
Map<String, Object> inputGuardrailMap = null;
Map<String, Object> outputGuardrailMap = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -93,20 +112,46 @@ public static Guardrails parse(XContentParser parser) throws IOException {
type = parser.text();
break;
case INPUT_GUARDRAIL_FIELD:
inputGuardrail = Guardrail.parse(parser);
inputGuardrailMap = parser.map();
break;
case OUTPUT_GUARDRAIL_FIELD:
outputGuardrail = Guardrail.parse(parser);
outputGuardrailMap = parser.map();
break;
default:
parser.skipChildren();
break;
}
}
if (!validateType(type)) {
throw new IllegalArgumentException("The type of guardrails is required, can not be null.");
}

return Guardrails.builder()
.type(type)
.inputGuardrail(inputGuardrail)
.outputGuardrail(outputGuardrail)
.inputGuardrail(createGuardrail(type, inputGuardrailMap))
.outputGuardrail(createGuardrail(type, outputGuardrailMap))
.build();
}

private static Boolean validateType(String type) {
if (types.contains(type)) {
return true;
}
return false;
}

private static Guardrail createGuardrail(String type, Map<String, Object> params) {
if (params == null || params.isEmpty()) {
return null;
}

switch (type) {
case "local_regex":
return new LocalRegexGuardrail(params);
case "model":
return new ModelGuardrail(params);
default:
throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type));
}
}
}
Loading

0 comments on commit c58168b

Please sign in to comment.