Skip to content

Commit

Permalink
Support more parameters for AD and KMEANS command, and update related…
Browse files Browse the repository at this point in the history
… documentation (#515)

Signed-off-by: jackieyanghan <[email protected]>
  • Loading branch information
jackiehanyang authored and vamsi-amazon committed Apr 19, 2022
1 parent 74e53e9 commit 8042c65
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 134 deletions.
4 changes: 2 additions & 2 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ public LogicalPlan visitValues(Values node, AnalysisContext context) {
@Override
public LogicalPlan visitKmeans(Kmeans node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
List<Argument> options = node.getOptions();
java.util.Map<String, Literal> options = node.getArguments();

TypeEnvironment currentEnv = context.peek();
currentEnv.define(new Symbol(Namespace.FIELD_NAME, "ClusterID"), ExprCoreType.INTEGER);
Expand All @@ -430,7 +430,7 @@ public LogicalPlan visitAD(AD node, AnalysisContext context) {
TypeEnvironment currentEnv = context.peek();

currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_SCORE), ExprCoreType.DOUBLE);
if (Objects.isNull(node.getArguments().get(TIME_FIELD).getValue())) {
if (Objects.isNull(node.getArguments().get(TIME_FIELD))) {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALOUS), ExprCoreType.BOOLEAN);
} else {
currentEnv.define(new Symbol(Namespace.FIELD_NAME, RCF_ANOMALY_GRADE), ExprCoreType.DOUBLE);
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/java/org/opensearch/sql/ast/tree/Kmeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Literal;

@Getter
@Setter
Expand All @@ -26,7 +27,7 @@
public class Kmeans extends UnresolvedPlan {
private UnresolvedPlan child;

private final List<Argument> options;
private final Map<String, Literal> arguments;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Literal;

/**
* ml-commons logical plan.
Expand All @@ -16,7 +16,7 @@
public class LogicalMLCommons extends LogicalPlan {
private final String algorithm;

private final List<Argument> arguments;
private final Map<String, Literal> arguments;

/**
* Constructor of LogicalMLCommons.
Expand All @@ -25,7 +25,7 @@ public class LogicalMLCommons extends LogicalPlan {
* @param arguments arguments of the algorithm
*/
public LogicalMLCommons(LogicalPlan child, String algorithm,
List<Argument> arguments) {
Map<String, Literal> arguments) {
super(Collections.singletonList(child));
this.algorithm = algorithm;
this.arguments = arguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,26 @@

public class MLCommonsConstants {

// AD constants
public static final String NUMBER_OF_TREES = "number_of_trees";
public static final String SHINGLE_SIZE = "shingle_size";
public static final String SAMPLE_SIZE = "sample_size";
public static final String OUTPUT_AFTER = "output_after";
public static final String TIME_DECAY = "time_decay";
public static final String ANOMALY_RATE = "anomaly_rate";
public static final String TIME_FIELD = "time_field";
public static final String DATE_FORMAT = "date_format";
public static final String TIME_ZONE = "time_zone";
public static final String TRAINING_DATA_SIZE = "training_data_size";
public static final String ANOMALY_SCORE_THRESHOLD = "anomaly_score_threshold";

public static final String RCF_SCORE = "score";
public static final String RCF_ANOMALOUS = "anomalous";
public static final String RCF_ANOMALY_GRADE = "anomaly_grade";
public static final String RCF_TIMESTAMP = "timestamp";

// KMEANS constants
public static final String CENTROIDS = "centroids";
public static final String ITERATIONS = "iterations";
public static final String DISTANCE_TYPE = "distance_type";
}
13 changes: 7 additions & 6 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -701,12 +701,15 @@ public void parse_relation() {

@Test
public void kmeanns_relation() {
Map<String, Literal> argumentMap = new HashMap<String, Literal>() {{
put("centroids", new Literal(3, DataType.INTEGER));
put("iterations", new Literal(2, DataType.INTEGER));
put("distance_type", new Literal("COSINE", DataType.STRING));
}};
assertAnalyzeEqual(
new LogicalMLCommons(LogicalPlanDSL.relation("schema"),
"kmeans",
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3)))),
new Kmeans(AstDSL.relation("schema"),
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))))
"kmeans", argumentMap),
new Kmeans(AstDSL.relation("schema"), argumentMap)
);
}

Expand All @@ -715,8 +718,6 @@ public void ad_batchRCF_relation() {
Map<String, Literal> argumentMap =
new HashMap<String, Literal>() {{
put("shingle_size", new Literal(8, DataType.INTEGER));
put("time_decay", new Literal(0.0001, DataType.DOUBLE));
put("time_field", new Literal(null, DataType.STRING));
}};
assertAnalyzeEqual(
new LogicalAD(LogicalPlanDSL.relation("schema"), argumentMap),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() {

LogicalPlan mlCommons = new LogicalMLCommons(LogicalPlanDSL.relation("schema"),
"kmeans",
AstDSL.exprList(AstDSL.argument("k", AstDSL.intLiteral(3))));
ImmutableMap.<String, Literal>builder()
.put("centroids", new Literal(3, DataType.INTEGER))
.put("iterations", new Literal(3, DataType.DOUBLE))
.put("distance_type", new Literal(null, DataType.STRING))
.build());
assertNull(mlCommons.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

Expand Down
24 changes: 16 additions & 8 deletions docs/user/ppl/cmd/ad.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,28 @@ Description

Fixed In Time RCF For Time-series Data Command Syntax
=====================================================
ad <shingle_size> <time_decay> <time_field>
ad <number_of_trees> <shingle_size> <sample_size> <output_after> <time_decay> <anomaly_rate> <time_field> <date_format> <time_zone>

* shingle_size: optional. A shingle is a consecutive sequence of the most recent records. The default value is 8.
* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001.
* time_field: mandatory. It specifies the time filed for RCF to use as time-series data.
* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30.
* shingle_size(integer): optional. A shingle is a consecutive sequence of the most recent records. The default value is 8.
* sample_size(integer): optional. The sample size used by stream samplers in this forest. The default value is 256.
* output_after(integer): optional. The number of points required by stream samplers before results are returned. The default value is 32.
* time_decay(double): optional. The decay factor used by stream samplers in this forest. The default value is 0.0001.
* anomaly_rate(double): optional. The anomaly rate. The default value is 0.005.
* time_field(string): mandatory. It specifies the time filed for RCF to use as time-series data.
* date_format(string): optional. It's used for formatting time_field field. The default formatting is "yyyy-MM-dd HH:mm:ss".
* time_zone(string): optional. It's used for setting time zone for time_field filed. The default time zone is UTC.


Batch RCF for Non-time-series Data Command Syntax
=================================================
ad <shingle_size> <time_decay>

* shingle_size: optional. A shingle is a consecutive sequence of the most recent records. The default value is 8.
* time_decay: optional. It specifies how much of the recent past to consider when computing an anomaly score. The default value is 0.001.
ad <number_of_trees> <sample_size> <output_after> <training_data_size> <anomaly_score_threshold>

* number_of_trees(integer): optional. Number of trees in the forest. The default value is 30.
* sample_size(integer): optional. Number of random samples given to each tree from the training data set. The default value is 256.
* output_after(integer): optional. The number of points required by stream samplers before results are returned. The default value is 32.
* training_data_size(integer): optional. The default value is the size of your training data set.
* anomaly_score_threshold(double): optional. The threshold of anomaly score. The default value is 1.0.

Example1: Detecting events in New York City from taxi ridership data with time-series data
==========================================================================================
Expand Down
8 changes: 5 additions & 3 deletions docs/user/ppl/cmd/kmeans.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ Description

Syntax
======
kmeans <cluster-number>
kmeans <centroids> <iterations> <distance_type>

* cluster-number: mandatory. The number of clusters you want to group your data points into.
* centroids: optional. The number of clusters you want to group your data points into. The default value is 2.
* iterations: optional. Number of iterations. The default value is 10.
* distance_type: optional. The distance type can be COSINE, L1, or EUCLIDEAN, The default type is EUCLIDEAN.


Example: Clustering of Iris Dataset
Expand All @@ -28,7 +30,7 @@ The example shows how to classify three Iris species (Iris setosa, Iris virginic

PPL query::

os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans 3
os> source=iris_data | fields sepal_length_in_cm, sepal_width_in_cm, petal_length_in_cm, petal_width_in_cm | kmeans centroids=3
+--------------------+-------------------+--------------------+-------------------+-----------+
| sepal_length_in_cm | sepal_width_in_cm | petal_length_in_cm | petal_width_in_cm | ClusterID |
|--------------------+-------------------+--------------------+-------------------+-----------|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@

package org.opensearch.sql.opensearch.planner.physical;

import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_RATE;
import static org.opensearch.sql.utils.MLCommonsConstants.ANOMALY_SCORE_THRESHOLD;
import static org.opensearch.sql.utils.MLCommonsConstants.DATE_FORMAT;
import static org.opensearch.sql.utils.MLCommonsConstants.NUMBER_OF_TREES;
import static org.opensearch.sql.utils.MLCommonsConstants.OUTPUT_AFTER;
import static org.opensearch.sql.utils.MLCommonsConstants.SAMPLE_SIZE;
import static org.opensearch.sql.utils.MLCommonsConstants.SHINGLE_SIZE;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_DECAY;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_ZONE;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINING_DATA_SIZE;

import java.util.Collections;
import java.util.Iterator;
Expand Down Expand Up @@ -97,18 +105,55 @@ public List<PhysicalPlan> getChild() {
}

protected MLAlgoParams convertArgumentToMLParameter(Map<String, Literal> arguments) {
if (arguments.get(TIME_FIELD).getValue() == null) {
if (arguments.get(TIME_FIELD) == null) {
rcfType = FunctionName.BATCH_RCF;
return BatchRCFParams.builder()
.shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue())
.numberOfTrees(arguments.containsKey(NUMBER_OF_TREES)
? ((Integer) arguments.get(NUMBER_OF_TREES).getValue())
: null)
.sampleSize(arguments.containsKey(SAMPLE_SIZE)
? ((Integer) arguments.get(SAMPLE_SIZE).getValue())
: null)
.outputAfter(arguments.containsKey(OUTPUT_AFTER)
? ((Integer) arguments.get(OUTPUT_AFTER).getValue())
: null)
.trainingDataSize(arguments.containsKey(TRAINING_DATA_SIZE)
? ((Integer) arguments.get(TRAINING_DATA_SIZE).getValue())
: null)
.anomalyScoreThreshold(arguments.containsKey(ANOMALY_SCORE_THRESHOLD)
? ((Double) arguments.get(ANOMALY_SCORE_THRESHOLD).getValue())
: null)
.build();
}
rcfType = FunctionName.FIT_RCF;
return FitRCFParams.builder()
.shingleSize((Integer) arguments.get(SHINGLE_SIZE).getValue())
.timeDecay((Double) arguments.get(TIME_DECAY).getValue())
.timeField((String) arguments.get(TIME_FIELD).getValue())
.dateFormat("yyyy-MM-dd HH:mm:ss")
.numberOfTrees(arguments.containsKey(NUMBER_OF_TREES)
? ((Integer) arguments.get(NUMBER_OF_TREES).getValue())
: null)
.shingleSize(arguments.containsKey(SHINGLE_SIZE)
? ((Integer) arguments.get(SHINGLE_SIZE).getValue())
: null)
.sampleSize(arguments.containsKey(SAMPLE_SIZE)
? ((Integer) arguments.get(SAMPLE_SIZE).getValue())
: null)
.outputAfter(arguments.containsKey(OUTPUT_AFTER)
? ((Integer) arguments.get(OUTPUT_AFTER).getValue())
: null)
.timeDecay(arguments.containsKey(TIME_DECAY)
? ((Double) arguments.get(TIME_DECAY).getValue())
: null)
.anomalyRate(arguments.containsKey(ANOMALY_RATE)
? ((Double) arguments.get(ANOMALY_RATE).getValue())
: null)
.timeField(arguments.containsKey(TIME_FIELD)
? ((String) arguments.get(TIME_FIELD).getValue())
: null)
.dateFormat(arguments.containsKey(DATE_FORMAT)
? ((String) arguments.get(DATE_FORMAT).getValue())
: "yyyy-MM-dd HH:mm:ss")
.timeZone(arguments.containsKey(TIME_ZONE)
? ((String) arguments.get(TIME_ZONE).getValue())
: null)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
package org.opensearch.sql.opensearch.planner.physical;

import static org.opensearch.ml.common.parameter.FunctionName.KMEANS;
import static org.opensearch.sql.utils.MLCommonsConstants.CENTROIDS;
import static org.opensearch.sql.utils.MLCommonsConstants.DISTANCE_TYPE;
import static org.opensearch.sql.utils.MLCommonsConstants.ITERATIONS;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
Expand All @@ -21,6 +25,7 @@
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLPredictionOutput;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.planner.physical.PhysicalPlan;
import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor;
Expand All @@ -39,7 +44,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions {
private final String algorithm;

@Getter
private final List<Argument> arguments;
private final Map<String, Literal> arguments;

@Getter
private final NodeClient nodeClient;
Expand All @@ -51,7 +56,7 @@ public class MLCommonsOperator extends MLCommonsOperatorActions {
public void open() {
super.open();
DataFrame inputDataFrame = generateInputDataset(input);
MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments.get(0), algorithm);
MLAlgoParams mlAlgoParams = convertArgumentToMLParameter(arguments, algorithm);
MLPredictionOutput predictionResult =
getMLPredictionResult(FunctionName.valueOf(algorithm.toUpperCase()),
mlAlgoParams, inputDataFrame, nodeClient);
Expand Down Expand Up @@ -91,15 +96,24 @@ public List<PhysicalPlan> getChild() {
return Collections.singletonList(input);
}

protected MLAlgoParams convertArgumentToMLParameter(Argument argument, String algorithm) {
protected MLAlgoParams convertArgumentToMLParameter(Map<String, Literal> arguments,
String algorithm) {
switch (FunctionName.valueOf(algorithm.toUpperCase())) {
case KMEANS:
if (argument.getValue().getValue() instanceof Number) {
return KMeansParams.builder().centroids((Integer) argument.getValue().getValue()).build();
} else {
throw new IllegalArgumentException("unsupported Kmeans argument type:"
+ argument.getValue().getType());
}
return KMeansParams.builder()
.centroids(arguments.containsKey(CENTROIDS)
? ((Integer) arguments.get(CENTROIDS).getValue())
: null)
.iterations(arguments.containsKey(ITERATIONS)
? ((Integer) arguments.get(ITERATIONS).getValue())
: null)
.distanceType(arguments.containsKey(DISTANCE_TYPE)
? (arguments.get(DISTANCE_TYPE).getValue() != null
? KMeansParams.DistanceType.valueOf((
(String) arguments.get(DISTANCE_TYPE).getValue()).toUpperCase())
: null)
: null)
.build();
default:
// TODO: update available algorithms in the message when adding a new case
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,13 @@ public void testVisitMlCommons() {
NodeClient nodeClient = mock(NodeClient.class);
MLCommonsOperator mlCommonsOperator =
new MLCommonsOperator(
values(emptyList()),
"kmeans",
AstDSL.exprList(AstDSL.argument("k1", AstDSL.intLiteral(3))),
nodeClient
);
values(emptyList()), "kmeans",
new HashMap<String, Literal>() {{
put("centroids", new Literal(3, DataType.INTEGER));
put("iterations", new Literal(2, DataType.INTEGER));
put("distance_type", new Literal(null, DataType.STRING));
}
}, nodeClient);

assertEquals(executionProtector.doProtect(mlCommonsOperator),
executionProtector.visitMLCommons(mlCommonsOperator, null));
Expand Down
Loading

0 comments on commit 8042c65

Please sign in to comment.