From 195d9a9feec291ccd967f5f773ac9fa9ec23416a Mon Sep 17 00:00:00 2001 From: Jing Zhang Date: Fri, 31 May 2024 10:23:13 -0700 Subject: [PATCH] guardrails model support Signed-off-by: Jing Zhang --- .../opensearch/ml/common/model/Guardrail.java | 100 +------ .../ml/common/model/Guardrails.java | 61 +++- .../ml/common/model/LocalRegexGuardrail.java | 267 +++++++++++++++++ .../opensearch/ml/common/model/MLGuard.java | 163 +---------- .../ml/common/model/ModelGuardrail.java | 257 +++++++++++++++++ .../opensearch/ml/common/model/StopWords.java | 7 + .../ml/common/model/GuardrailTests.java | 69 ----- .../ml/common/model/GuardrailsTests.java | 24 +- .../model/LocalRegexGuardrailTests.java | 270 ++++++++++++++++++ .../ml/common/model/MLGuardTests.java | 206 +------------ 10 files changed, 887 insertions(+), 537 deletions(-) create mode 100644 common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java create mode 100644 common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java create mode 100644 common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java index d690fdce7f..03554c3e48 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrail.java @@ -1,105 +1,17 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.client.Client; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -@EqualsAndHashCode -@Getter -public class Guardrail implements ToXContentObject { - public static final String STOP_WORDS_FIELD = "stop_words"; - public static final String REGEX_FIELD = "regex"; - - private List stopWords; - private String[] regex; - - @Builder(toBuilder = true) - public Guardrail(List stopWords, String[] regex) { - this.stopWords = stopWords; - this.regex = regex; - } - - public Guardrail(StreamInput input) throws IOException { - if (input.readBoolean()) { - stopWords = new ArrayList<>(); - int size = input.readInt(); - for (int i=0; i 0) { - out.writeBoolean(true); - out.writeInt(stopWords.size()); - for (StopWords e : stopWords) { - e.writeTo(out); - } - } else { - out.writeBoolean(false); - } - out.writeStringArray(regex); - } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (stopWords != null && stopWords.size() > 0) { - builder.field(STOP_WORDS_FIELD, stopWords); - } - if (regex != null) { - builder.field(REGEX_FIELD, regex); - } - builder.endObject(); - return builder; - } +public abstract class Guardrail implements ToXContentObject { - public static Guardrail parse(XContentParser parser) throws IOException { - List stopWords = null; - String[] regex = null; + public abstract void writeTo(StreamOutput out) throws IOException; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); + public abstract Boolean validate(String input); - switch (fieldName) { - case STOP_WORDS_FIELD: - stopWords = new ArrayList<>(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - stopWords.add(StopWords.parse(parser)); - } - break; - case REGEX_FIELD: - regex = parser.list().toArray(new String[0]); - break; - default: - parser.skipChildren(); - break; - } - } - return Guardrail.builder() - .stopWords(stopWords) - .regex(regex) - .build(); - } + public abstract void init(NamedXContentRegistry xContentRegistry, Client client); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java index 1153262935..758f38d43c 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java +++ b/common/src/main/java/org/opensearch/ml/common/model/Guardrails.java @@ -15,6 +15,8 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; +import java.util.Set; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -39,10 +41,26 @@ public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardra public Guardrails(StreamInput input) throws IOException { type = input.readString(); if (input.readBoolean()) { - inputGuardrail = new Guardrail(input); + switch (type) { + case "local_regex": + inputGuardrail = new LocalRegexGuardrail(input); + break; + case "model": + break; + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } } if (input.readBoolean()) { - outputGuardrail = new Guardrail(input); + switch (type) { + case "local_regex": + outputGuardrail = new LocalRegexGuardrail(input); + break; + case "model": + break; + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } } } @@ -80,8 +98,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public static Guardrails parse(XContentParser parser) throws IOException { String type = null; - Guardrail inputGuardrail = null; - Guardrail outputGuardrail = null; + Map inputGuardrailMap = null; + Map outputGuardrailMap = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -93,20 +111,47 @@ public static Guardrails parse(XContentParser parser) throws IOException { type = parser.text(); break; case INPUT_GUARDRAIL_FIELD: - inputGuardrail = Guardrail.parse(parser); + inputGuardrailMap = parser.map(); break; case OUTPUT_GUARDRAIL_FIELD: - outputGuardrail = Guardrail.parse(parser); + outputGuardrailMap = parser.map(); break; default: parser.skipChildren(); break; } } + if (!validateType(type)) { + throw new IllegalArgumentException("The type of guardrails is required, can not be null."); + } + return Guardrails.builder() .type(type) - .inputGuardrail(inputGuardrail) - .outputGuardrail(outputGuardrail) + .inputGuardrail(createGuardrail(type, inputGuardrailMap)) + .outputGuardrail(createGuardrail(type, outputGuardrailMap)) .build(); } + + private static Boolean validateType(String type) { + Set types = Set.of("local_regex", "model"); + if (types.contains(type)) { + return true; + } + return false; + } + + private static Guardrail createGuardrail(String type, Map params) { + if (params == null || params.isEmpty()) { + return null; + } + + switch (type) { + case "local_regex": + return new LocalRegexGuardrail(params); + case "model": + return new ModelGuardrail(params); + default: + throw new IllegalArgumentException(String.format("Unsupported guardrails type: %s", type)); + } + } } diff --git a/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java new file mode 100644 index 0000000000..a71e71e73e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/LocalRegexGuardrail.java @@ -0,0 +1,267 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import com.google.common.collect.ImmutableSet; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; + +@Log4j2 +@EqualsAndHashCode +@Getter +public class LocalRegexGuardrail extends Guardrail { + public static final String STOP_WORDS_FIELD = "stop_words"; + public static final String REGEX_FIELD = "regex"; + + private List stopWords; + private String[] regex; + private List regexPattern; + private Map> stopWordsIndicesInput; + private NamedXContentRegistry xContentRegistry; + private Client client; + private Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); + + @Builder(toBuilder = true) + public LocalRegexGuardrail(List stopWords, String[] regex) { + this.stopWords = stopWords; + this.regex = regex; + } + public LocalRegexGuardrail(@NonNull Map params) { + Object words = params.get(STOP_WORDS_FIELD); + stopWords = new ArrayList<>(); + for (Map e : (List)words) { + stopWords.add(new StopWords(e)); + } + this.regex = ((List) params.get(REGEX_FIELD)).toArray(new String[0]); + } + + public LocalRegexGuardrail(StreamInput input) throws IOException { + if (input.readBoolean()) { + stopWords = new ArrayList<>(); + int size = input.readInt(); + for (int i=0; i 0) { + out.writeBoolean(true); + out.writeInt(stopWords.size()); + for (StopWords e : stopWords) { + e.writeTo(out); + } + } else { + out.writeBoolean(false); + } + out.writeStringArray(regex); + } + + @Override + public Boolean validate(String input) { + return validateRegexList(input, regexPattern) && validateStopWords(input, stopWordsIndicesInput); + } + + @Override + public void init(NamedXContentRegistry xContentRegistry, Client client) { + this.xContentRegistry = xContentRegistry; + this.client = client; + init(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (stopWords != null && stopWords.size() > 0) { + builder.field(STOP_WORDS_FIELD, stopWords); + } + if (regex != null) { + builder.field(REGEX_FIELD, regex); + } + builder.endObject(); + return builder; + } + + public static LocalRegexGuardrail parse(XContentParser parser) throws IOException { + List stopWords = null; + String[] regex = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case STOP_WORDS_FIELD: + stopWords = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + stopWords.add(StopWords.parse(parser)); + } + break; + case REGEX_FIELD: + regex = parser.list().toArray(new String[0]); + break; + default: + parser.skipChildren(); + break; + } + } + return LocalRegexGuardrail.builder() + .stopWords(stopWords) + .regex(regex) + .build(); + } + + private void init() { + stopWordsIndicesInput = stopWordsToMap(); + List regexList = regex == null ? new ArrayList<>() : Arrays.asList(regex); + regexPattern = regexList.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); + } + + private Map> stopWordsToMap() { + Map> map = new HashMap<>(); + if (stopWords != null && !stopWords.isEmpty()) { + for (StopWords e : stopWords) { + if (e.getIndex() != null && e.getSourceFields() != null) { + map.put(e.getIndex(), Arrays.asList(e.getSourceFields())); + } + } + } + return map; + } + + public Boolean validateRegexList(String input, List regexPatterns) { + if (regexPatterns == null || regexPatterns.isEmpty()) { + return true; + } + for (Pattern pattern : regexPatterns) { + if (!validateRegex(input, pattern)) { + return false; + } + } + return true; + } + + public Boolean validateRegex(String input, Pattern pattern) { + Matcher matcher = pattern.matcher(input); + return !matcher.matches(); + } + + public Boolean validateStopWords(String input, Map> stopWordsIndices) { + if (stopWordsIndices == null || stopWordsIndices.isEmpty()) { + return true; + } + for (Map.Entry entry : stopWordsIndices.entrySet()) { + if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) { + return false; + } + } + return true; + } + + public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { + SearchRequest searchRequest; + AtomicBoolean hitStopWords = new AtomicBoolean(false); + String queryBody; + Map documentMap = new HashMap<>(); + for (String field : fieldNames) { + documentMap.put(field, input); + } + Map queryBodyMap = Map + .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); + CountDownLatch latch = new CountDownLatch(1); + ThreadContext.StoredContext context = null; + + try { + queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); + searchSourceBuilder.parseXContent(queryParser); + searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. + searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); + if (isStopWordsSystemIndex(indexName)) { + context = client.threadPool().getThreadContext().stashContext(); + ThreadContext.StoredContext finalContext = context; + client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch), () -> finalContext.restore())); + } else { + client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + hitStopWords.set(true); + } + }, e -> { + log.error("Failed to search stop words index {}", indexName, e); + hitStopWords.set(true); + }), latch)); + } + } catch (Exception e) { + log.error("[validateStopWords] Searching stop words index failed.", e); + latch.countDown(); + hitStopWords.set(true); + } finally { + if (context != null) { + context.close(); + } + } + + try { + latch.await(5, SECONDS); + } catch (InterruptedException e) { + log.error("[validateStopWords] Searching stop words index was timeout.", e); + throw new IllegalStateException(e); + } + return hitStopWords.get(); + } + + private boolean isStopWordsSystemIndex(String index) { + return stopWordsIndices.contains(index); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java index dcb0e65ad7..8d7fa9f30a 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLGuard.java @@ -5,188 +5,41 @@ package org.opensearch.ml.common.model; -import com.google.common.collect.ImmutableSet; import lombok.Getter; -import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.builder.SearchSourceBuilder; - -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.ml.common.utils.StringUtils.gson; @Log4j2 @Getter public class MLGuard { - private Map> stopWordsIndicesInput = new HashMap<>(); - private Map> stopWordsIndicesOutput = new HashMap<>(); - private List inputRegex; - private List outputRegex; - private List inputRegexPattern; - private List outputRegexPattern; private NamedXContentRegistry xContentRegistry; private Client client; - private Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); + private Guardrails guardrails; public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) { this.xContentRegistry = xContentRegistry; this.client = client; - if (guardrails == null) { - return; - } - Guardrail inputGuardrail = guardrails.getInputGuardrail(); - Guardrail outputGuardrail = guardrails.getOutputGuardrail(); - if (inputGuardrail != null) { - fillStopWordsToMap(inputGuardrail, stopWordsIndicesInput); - inputRegex = inputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(inputGuardrail.getRegex()); - inputRegexPattern = inputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); - } - if (outputGuardrail != null) { - fillStopWordsToMap(outputGuardrail, stopWordsIndicesOutput); - outputRegex = outputGuardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(outputGuardrail.getRegex()); - outputRegexPattern = outputRegex.stream().map(reg -> Pattern.compile(reg)).collect(Collectors.toList()); - } - } - - private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map> map) { - List stopWords = guardrail.getStopWords(); - if (stopWords == null || stopWords.isEmpty()) { - return; + this.guardrails = guardrails; + if (this.guardrails != null && this.guardrails.getInputGuardrail() != null) { + this.guardrails.getInputGuardrail().init(xContentRegistry, client); } - for (StopWords e : stopWords) { - if (e.getIndex() != null && e.getSourceFields() != null) { - map.put(e.getIndex(), Arrays.asList(e.getSourceFields())); - } + if (this.guardrails != null && this.guardrails.getOutputGuardrail() != null) { + this.guardrails.getOutputGuardrail().init(xContentRegistry, client); } } public Boolean validate(String input, Type type) { switch (type) { case INPUT: // validate input - return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput); + return guardrails.getInputGuardrail() == null ? true : guardrails.getInputGuardrail().validate(input); case OUTPUT: // validate output - return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput); + return guardrails.getOutputGuardrail() == null ? true : guardrails.getOutputGuardrail().validate(input); default: throw new IllegalArgumentException("Unsupported type to validate for guardrails."); } } - public Boolean validateRegexList(String input, List regexPatterns) { - if (regexPatterns == null || regexPatterns.isEmpty()) { - return true; - } - for (Pattern pattern : regexPatterns) { - if (!validateRegex(input, pattern)) { - return false; - } - } - return true; - } - - public Boolean validateRegex(String input, Pattern pattern) { - Matcher matcher = pattern.matcher(input); - return !matcher.matches(); - } - - public Boolean validateStopWords(String input, Map> stopWordsIndices) { - if (stopWordsIndices == null || stopWordsIndices.isEmpty()) { - return true; - } - for (Map.Entry entry : stopWordsIndices.entrySet()) { - if (!validateStopWordsSingleIndex(input, (String) entry.getKey(), (List) entry.getValue())) { - return false; - } - } - return true; - } - - public Boolean validateStopWordsSingleIndex(String input, String indexName, List fieldNames) { - SearchRequest searchRequest; - AtomicBoolean hitStopWords = new AtomicBoolean(false); - String queryBody; - Map documentMap = new HashMap<>(); - for (String field : fieldNames) { - documentMap.put(field, input); - } - Map queryBodyMap = Map - .of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap))); - CountDownLatch latch = new CountDownLatch(1); - ThreadContext.StoredContext context = null; - - try { - queryBody = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(queryBodyMap)); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody); - searchSourceBuilder.parseXContent(queryParser); - searchSourceBuilder.size(1); //Only need 1 doc returned, if hit. - searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName); - if (isStopWordsSystemIndex(indexName)) { - context = client.threadPool().getThreadContext().stashContext(); - ThreadContext.StoredContext finalContext = context; - client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch), () -> finalContext.restore())); - } else { - client.search(searchRequest, new LatchedActionListener(ActionListener.wrap(r -> { - if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { - hitStopWords.set(true); - } - }, e -> { - log.error("Failed to search stop words index {}", indexName, e); - hitStopWords.set(true); - }), latch)); - } - } catch (Exception e) { - log.error("[validateStopWords] Searching stop words index failed.", e); - latch.countDown(); - hitStopWords.set(true); - } finally { - if (context != null) { - context.close(); - } - } - - try { - latch.await(5, SECONDS); - } catch (InterruptedException e) { - log.error("[validateStopWords] Searching stop words index was timeout.", e); - throw new IllegalStateException(e); - } - return hitStopWords.get(); - } - - private boolean isStopWordsSystemIndex(String index) { - return stopWordsIndices.contains(index); - } - public enum Type { INPUT, OUTPUT diff --git a/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java new file mode 100644 index 0000000000..7adcd01f6e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/model/ModelGuardrail.java @@ -0,0 +1,257 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +@Log4j2 +@EqualsAndHashCode +@Getter +public class ModelGuardrail extends Guardrail { + public static final String MODEL_ID_FIELD = "model_id"; + public static final String RESPONSE_FILTER_FIELD = "response_filter"; + public static final String RESPONSE_ACCEPT_FIELD = "response_accept"; + public static final String PARAM_FIELD = "parameters"; + + private String modelId; + private String responseFilter; + private String responseAccept; + private Map parameters; + private NamedXContentRegistry xContentRegistry; + private Client client; + private Pattern regexAcceptPattern; + + @Builder(toBuilder = true) + public ModelGuardrail(String modelId, String responseFilter, String responseAccept, Map parameters) { + this.modelId = modelId; + this.responseFilter = responseFilter; + this.responseAccept = responseAccept; + this.parameters = parameters; + } + public ModelGuardrail(@NonNull Map params) { + this((String) params.get(MODEL_ID_FIELD), (String) params.get(RESPONSE_FILTER_FIELD), (String) params.get(RESPONSE_ACCEPT_FIELD), (Map) params.get(PARAM_FIELD)); + } + + public ModelGuardrail(StreamInput input) throws IOException { + modelId = input.readString(); + responseFilter = input.readString(); + responseAccept = input.readString(); + if (input.readBoolean()) { + parameters = input.readMap(s -> s.readString(), s-> s.readString()); + } + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeString(responseFilter); + out.writeString(responseAccept); + if (parameters != null && parameters.size() > 0) { + out.writeBoolean(true); + out.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); + } else { + out.writeBoolean(false); + } + } + + private Boolean validateAcceptRegex(String input) { + Matcher matcher = regexAcceptPattern.matcher(input); + return matcher.matches(); + } + + @Override + public Boolean validate(String in) { + String input = StringUtils.processTextDoc(in); + AtomicBoolean isAccepted = new AtomicBoolean(true); + ActionListener internalListener = ActionListener.wrap(predictionResponse -> { + ModelTensorOutput output = (ModelTensorOutput) predictionResponse.getOutput(); + ModelTensor tensor = output.getMlModelOutputs().get(0).getMlModelTensors().get(0); + String guardrailResponse = (String) tensor.getDataAsMap().get(responseFilter); + log.info("Guardrail response: {}", guardrailResponse); + if (!validateAcceptRegex(guardrailResponse)) { + isAccepted.set(false); + } + }, e -> {log.error("[ModelGuardrail] Failed to get prediction response.", e);}); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res); + return predictionResponse; + }); + CountDownLatch latch = new CountDownLatch(1); + ActionRequest request = new MLPredictionTaskRequest( + modelId, + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(Map.of("question", input)).build()) + .build() + ); + client + .execute( + MLPredictionTaskAction.INSTANCE, + request, + new LatchedActionListener(actionListener, latch) + ); + try { + latch.await(5, SECONDS); + } catch (InterruptedException e) { + log.error("[ModelGuardrail] Validation was timeout.", e); + throw new IllegalStateException(e); + } + + return isAccepted.get(); + } + + @Override + public void init(NamedXContentRegistry xContentRegistry, Client client) { + this.xContentRegistry = xContentRegistry; + this.client = client; + regexAcceptPattern = Pattern.compile(responseAccept); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (modelId != null) { + builder.field(MODEL_ID_FIELD, modelId); + } + if (responseFilter != null) { + builder.field(RESPONSE_FILTER_FIELD, responseFilter); + } + if (responseAccept != null) { + builder.field(RESPONSE_ACCEPT_FIELD, responseAccept); + } + if (parameters != null && parameters.size() > 0) { + builder.field(PARAM_FIELD, parameters); + } + builder.endObject(); + return builder; + } + + public static ModelGuardrail parse(XContentParser parser) throws IOException { + String modelId = null; + String responseFilter = null; + String responseAccept = null; + Map parameters = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MODEL_ID_FIELD: + modelId = parser.text(); + break; + case RESPONSE_FILTER_FIELD: + responseFilter = parser.text(); + break; + case RESPONSE_ACCEPT_FIELD: + responseAccept = parser.text(); + break; + case PARAM_FIELD: + parameters = getParameterMap(parser.map()); + break; + default: + parser.skipChildren(); + break; + } + } + return ModelGuardrail.builder() + .modelId(modelId) + .responseFilter(responseFilter) + .responseAccept(responseAccept) + .parameters(parameters) + .build(); + } + + private ActionListener wrapActionListener( + final ActionListener listener, + final Function recreate + ) { + ActionListener actionListener = ActionListener.wrap(r -> { + listener.onResponse(recreate.apply(r)); + ; + }, e -> { listener.onFailure(e); }); + return actionListener; + } + + private XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) + throws IOException { + return XContentHelper.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, bytesReference, XContentType.JSON); + } + + private void getModel(String modelId, String[] includes, String[] excludes, ActionListener listener) { + GetRequest getRequest = new GetRequest(); + FetchSourceContext fetchContext = new FetchSourceContext(true, includes, excludes); + getRequest.index(ML_MODEL_INDEX).id(modelId).fetchSourceContext(fetchContext); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + String algorithmName = r.getSource().get(ALGORITHM_FIELD).toString(); + + MLModel mlModel = MLModel.parse(parser, algorithmName); + mlModel.setModelId(modelId); + listener.onResponse(mlModel); + } catch (Exception e) { + log.error("Failed to parse ml task" + r.getId(), e); + listener.onFailure(e); + } + } else { + listener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + } + }, listener::onFailure)); + } + + public void getModel(String modelId, ActionListener listener) { + getModel(modelId, null, null, listener); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java index 19307b398d..e65f82a96c 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/StopWords.java +++ b/common/src/main/java/org/opensearch/ml/common/model/StopWords.java @@ -8,6 +8,7 @@ import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.NonNull; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; @@ -15,6 +16,8 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.util.List; +import java.util.Map; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @@ -33,6 +36,10 @@ public StopWords(String index, String[] sourceFields) { this.sourceFields = sourceFields; } + public StopWords(@NonNull Map params) { + this((String) params.get(INDEX_NAME_FIELD), ((List) params.get(SOURCE_FIELDS_FIELD)).toArray(new String[0])); + } + public StopWords(StreamInput input) throws IOException { index = input.readString(); sourceFields = input.readStringArray(); diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java deleted file mode 100644 index b6b140d119..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailTests.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.model; - -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.search.SearchModule; - -import java.io.IOException; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.*; - -public class GuardrailTests { - StopWords stopWords; - String[] regex; - - @Before - public void setUp() { - stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); - regex = List.of("regex1").toArray(new String[0]); - } - - @Test - public void writeTo() throws IOException { - Guardrail guardrail = new Guardrail(List.of(stopWords), regex); - BytesStreamOutput output = new BytesStreamOutput(); - guardrail.writeTo(output); - Guardrail guardrail1 = new Guardrail(output.bytes().streamInput()); - - Assert.assertArrayEquals(guardrail.getStopWords().toArray(), guardrail1.getStopWords().toArray()); - Assert.assertArrayEquals(guardrail.getRegex(), guardrail1.getRegex()); - } - - @Test - public void toXContent() throws IOException { - Guardrail guardrail = new Guardrail(List.of(stopWords), regex); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - guardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); - String content = TestHelper.xContentBuilderToString(builder); - - Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}", content); - } - - @Test - public void parse() throws IOException { - String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); - parser.nextToken(); - Guardrail guardrail = Guardrail.parse(parser); - - Assert.assertArrayEquals(guardrail.getStopWords().toArray(), List.of(stopWords).toArray()); - Assert.assertArrayEquals(guardrail.getRegex(), regex); - } -} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java index dc0c3d116c..a1b589d07c 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java @@ -22,25 +22,23 @@ import java.util.Collections; import java.util.List; -import static org.junit.Assert.*; - public class GuardrailsTests { StopWords stopWords; String[] regex; - Guardrail inputGuardrail; - Guardrail outputGuardrail; + LocalRegexGuardrail inputLocalRegexGuardrail; + LocalRegexGuardrail outputLocalRegexGuardrail; @Before public void setUp() { stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); regex = List.of("regex1").toArray(new String[0]); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); } @Test public void writeTo() throws IOException { - Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + Guardrails guardrails = new Guardrails("local_regex", inputLocalRegexGuardrail, outputLocalRegexGuardrail); BytesStreamOutput output = new BytesStreamOutput(); guardrails.writeTo(output); Guardrails guardrails1 = new Guardrails(output.bytes().streamInput()); @@ -52,12 +50,12 @@ public void writeTo() throws IOException { @Test public void toXContent() throws IOException { - Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + Guardrails guardrails = new Guardrails("local_regex", inputLocalRegexGuardrail, outputLocalRegexGuardrail); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"type\":\"test_type\"," + + Assert.assertEquals("{\"type\":\"local_regex\"," + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}", content); @@ -65,7 +63,7 @@ public void toXContent() throws IOException { @Test public void parse() throws IOException { - String jsonStr = "{\"type\":\"test_type\"," + + String jsonStr = "{\"type\":\"local_regex\"," + "\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," + "\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}"; XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, @@ -73,8 +71,8 @@ public void parse() throws IOException { parser.nextToken(); Guardrails guardrails = Guardrails.parse(parser); - Assert.assertEquals(guardrails.getType(), "test_type"); - Assert.assertEquals(guardrails.getInputGuardrail(), inputGuardrail); - Assert.assertEquals(guardrails.getOutputGuardrail(), outputGuardrail); + Assert.assertEquals(guardrails.getType(), "local_regex"); + Assert.assertEquals(guardrails.getInputGuardrail(), inputLocalRegexGuardrail); + Assert.assertEquals(guardrails.getOutputGuardrail(), outputLocalRegexGuardrail); } } \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java new file mode 100644 index 0000000000..735ee9e332 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/model/LocalRegexGuardrailTests.java @@ -0,0 +1,270 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.model; + +import org.apache.lucene.search.TotalHits; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchModule; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; +import org.opensearch.search.suggest.Suggest; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +public class LocalRegexGuardrailTests { + NamedXContentRegistry xContentRegistry; + @Mock + Client client; + @Mock + ThreadPool threadPool; + ThreadContext threadContext; + + StopWords stopWords; + String[] regex; + List regexPatterns; + LocalRegexGuardrail localRegexGuardrail; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + xContentRegistry = new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + Settings settings = Settings.builder().build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + + stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); + regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); + regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); + localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + } + + @Test + public void writeTo() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + BytesStreamOutput output = new BytesStreamOutput(); + localRegexGuardrail.writeTo(output); + LocalRegexGuardrail localRegexGuardrail1 = new LocalRegexGuardrail(output.bytes().streamInput()); + + Assert.assertArrayEquals(localRegexGuardrail.getStopWords().toArray(), localRegexGuardrail1.getStopWords().toArray()); + Assert.assertArrayEquals(localRegexGuardrail.getRegex(), localRegexGuardrail1.getRegex()); + } + + @Test + public void toXContent() throws IOException { + LocalRegexGuardrail localRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + localRegexGuardrail.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + Assert.assertEquals("{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}", content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"(.|\\n)*stop words(.|\\n)*\"]}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + LocalRegexGuardrail localRegexGuardrail = LocalRegexGuardrail.parse(parser); + + Assert.assertArrayEquals(localRegexGuardrail.getStopWords().toArray(), List.of(stopWords).toArray()); + Assert.assertArrayEquals(localRegexGuardrail.getRegex(), regex); + } + + @Test + public void validateRegexListSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, regexPatterns); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexListFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, regexPatterns); + + Assert.assertFalse(res); + } + + @Test + public void validateRegexListNull() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, null); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexListEmpty() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegexList(input, List.of()); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexSuccess() { + String input = "\n\nHuman:hello good words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegex(input, regexPatterns.get(0)); + + Assert.assertTrue(res); + } + + @Test + public void validateRegexFailed() { + String input = "\n\nHuman:hello stop words.\n\nAssistant:"; + Boolean res = localRegexGuardrail.validateRegex(input, regexPatterns.get(0)); + + Assert.assertFalse(res); + } + + @Test + public void validateStopWords() throws IOException { + Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = localRegexGuardrail.validateStopWords("hello world", stopWordsIndices); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsNull() { + Boolean res = localRegexGuardrail.validateStopWords("hello world", null); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsEmpty() { + Boolean res = localRegexGuardrail.validateStopWords("hello world", Map.of()); + Assert.assertTrue(res); + } + + @Test + public void validateStopWordsSingleIndex() throws IOException { + SearchResponse searchResponse = createSearchResponse(1); + ActionFuture future = createSearchResponseFuture(searchResponse); + when(this.client.search(any())).thenReturn(future); + + Boolean res = localRegexGuardrail.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); + Assert.assertTrue(res); + } + + private SearchResponse createSearchResponse(int size) throws IOException { + XContentBuilder content = localRegexGuardrail.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + SearchHit[] hits = new SearchHit[size]; + if (size > 0) { + hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); + } + return new SearchResponse( + new InternalSearchResponse( + new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), + InternalAggregations.EMPTY, + new Suggest(Collections.emptyList()), + new SearchProfileShardResults(Collections.emptyMap()), + false, + false, + 1 + ), + "", + 5, + 5, + 0, + 100, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + } + + private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { + return new ActionFuture<>() { + @Override + public SearchResponse actionGet() { + return searchResponse; + } + + @Override + public SearchResponse actionGet(String timeout) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeoutMillis) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(long timeout, TimeUnit unit) { + return searchResponse; + } + + @Override + public SearchResponse actionGet(TimeValue timeout) { + return searchResponse; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public SearchResponse get() { + return searchResponse; + } + + @Override + public SearchResponse get(long timeout, TimeUnit unit) { + return searchResponse; + } + }; + } +} \ No newline at end of file diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java index 4af3072c8a..d6ecb2083a 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java @@ -5,43 +5,22 @@ package org.opensearch.ml.common.model; -import org.apache.lucene.search.TotalHits; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.client.Client; -import org.opensearch.common.action.ActionFuture; import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; -import org.opensearch.search.aggregations.InternalAggregations; -import org.opensearch.search.internal.InternalSearchResponse; -import org.opensearch.search.profile.SearchProfileShardResults; -import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import java.io.IOException; import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.regex.Pattern; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; public class MLGuardTests { @@ -56,8 +35,8 @@ public class MLGuardTests { StopWords stopWords; String[] regex; List regexPatterns; - Guardrail inputGuardrail; - Guardrail outputGuardrail; + LocalRegexGuardrail inputLocalRegexGuardrail; + LocalRegexGuardrail outputLocalRegexGuardrail; Guardrails guardrails; MLGuard mlGuard; @@ -73,9 +52,9 @@ public void setUp() { stopWords = new StopWords("test_index", List.of("test_field").toArray(new String[0])); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); - guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail); mlGuard = new MLGuard(guardrails, xContentRegistry, client); } @@ -92,182 +71,13 @@ public void validateInitializedStopWordsEmpty() { stopWords = new StopWords(null, null); regex = List.of("(.|\n)*stop words(.|\n)*").toArray(new String[0]); regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*")); - inputGuardrail = new Guardrail(List.of(stopWords), regex); - outputGuardrail = new Guardrail(List.of(stopWords), regex); - guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail); + inputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + outputLocalRegexGuardrail = new LocalRegexGuardrail(List.of(stopWords), regex); + guardrails = new Guardrails("test_type", inputLocalRegexGuardrail, outputLocalRegexGuardrail); mlGuard = new MLGuard(guardrails, xContentRegistry, client); String input = "\n\nHuman:hello good words.\n\nAssistant:"; Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT); Assert.assertTrue(res); } - - @Test - public void validateOutput() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validate(input, MLGuard.Type.OUTPUT); - - Assert.assertFalse(res); - } - - @Test - public void validateRegexListSuccess() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, regexPatterns); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexListFailed() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, regexPatterns); - - Assert.assertFalse(res); - } - - @Test - public void validateRegexListNull() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, null); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexListEmpty() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegexList(input, List.of()); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexSuccess() { - String input = "\n\nHuman:hello good words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); - - Assert.assertTrue(res); - } - - @Test - public void validateRegexFailed() { - String input = "\n\nHuman:hello stop words.\n\nAssistant:"; - Boolean res = mlGuard.validateRegex(input, regexPatterns.get(0)); - - Assert.assertFalse(res); - } - - @Test - public void validateStopWords() throws IOException { - Map> stopWordsIndices = Map.of("test_index", List.of("test_field")); - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); - - Boolean res = mlGuard.validateStopWords("hello world", stopWordsIndices); - Assert.assertTrue(res); - } - - @Test - public void validateStopWordsNull() { - Boolean res = mlGuard.validateStopWords("hello world", null); - Assert.assertTrue(res); - } - - @Test - public void validateStopWordsEmpty() { - Boolean res = mlGuard.validateStopWords("hello world", Map.of()); - Assert.assertTrue(res); - } - - @Test - public void validateStopWordsSingleIndex() throws IOException { - SearchResponse searchResponse = createSearchResponse(1); - ActionFuture future = createSearchResponseFuture(searchResponse); - when(this.client.search(any())).thenReturn(future); - - Boolean res = mlGuard.validateStopWordsSingleIndex("hello world", "test_index", List.of("test_field")); - Assert.assertTrue(res); - } - - private SearchResponse createSearchResponse(int size) throws IOException { - XContentBuilder content = guardrails.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - SearchHit[] hits = new SearchHit[size]; - if (size > 0) { - hits[0] = new SearchHit(0).sourceRef(BytesReference.bytes(content)); - } - return new SearchResponse( - new InternalSearchResponse( - new SearchHits(hits, new TotalHits(size, TotalHits.Relation.EQUAL_TO), 1.0f), - InternalAggregations.EMPTY, - new Suggest(Collections.emptyList()), - new SearchProfileShardResults(Collections.emptyMap()), - false, - false, - 1 - ), - "", - 5, - 5, - 0, - 100, - ShardSearchFailure.EMPTY_ARRAY, - SearchResponse.Clusters.EMPTY - ); - } - - private ActionFuture createSearchResponseFuture(SearchResponse searchResponse) { - return new ActionFuture<>() { - @Override - public SearchResponse actionGet() { - return searchResponse; - } - - @Override - public SearchResponse actionGet(String timeout) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeoutMillis) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(long timeout, TimeUnit unit) { - return searchResponse; - } - - @Override - public SearchResponse actionGet(TimeValue timeout) { - return searchResponse; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public SearchResponse get() throws InterruptedException, ExecutionException { - return searchResponse; - } - - @Override - public SearchResponse get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { - return searchResponse; - } - }; - } } \ No newline at end of file