Skip to content

Commit

Permalink
Aggregations Refactor: Refactor Scripted Metric Aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
colings86 committed Nov 24, 2015
1 parent 21556f9 commit 510848f
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 9 deletions.
20 changes: 16 additions & 4 deletions core/src/main/java/org/elasticsearch/script/Script.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,24 @@

import java.io.IOException;
import java.util.Map;
import java.util.function.Supplier;

/**
* Script holds all the parameters necessary to compile or find in cache and then execute a script.
*/
public class Script implements ToXContent, Streamable {

/**
* A {@link Supplier} implementation for use when reading a {@link Script}
* using {@link StreamInput#readOptionalStreamable(Supplier)}
*/
public static final Supplier<Script> SUPPLIER = new Supplier<Script>() {

@Override
public Script get() {
return new Script();
}
};
public static final ScriptType DEFAULT_TYPE = ScriptType.INLINE;
private static final ScriptParser PARSER = new ScriptParser();

Expand Down Expand Up @@ -74,7 +86,7 @@ protected Script(String script, String lang) {

/**
* Constructor for Script.
*
*
* @param script
* The cache key of the script to be compiled/executed. For
* inline scripts this is the actual script source code. For
Expand Down Expand Up @@ -112,7 +124,7 @@ public String getScript() {

/**
* Method for getting the type.
*
*
* @return The type of script -- inline, indexed, or file.
*/
public ScriptType getType() {
Expand All @@ -121,7 +133,7 @@ public ScriptType getType() {

/**
* Method for getting language.
*
*
* @return The language of the script to be compiled/executed.
*/
public String getLang() {
Expand All @@ -130,7 +142,7 @@ public String getLang() {

/**
* Method for getting the parameters.
*
*
* @return The map of parameters the script will be executed with.
*/
public Map<String, Object> getParams() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
package org.elasticsearch.search.aggregations.metrics.scripted;

import org.apache.lucene.index.LeafReaderContext;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.LeafSearchScript;
import org.elasticsearch.script.Script;
Expand All @@ -43,6 +46,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;

public class ScriptedMetricAggregator extends MetricsAggregator {

Expand Down Expand Up @@ -113,13 +117,43 @@ public static class Factory extends AggregatorFactory {
private Script reduceScript;
private Map<String, Object> params;

public Factory(String name, Script initScript, Script mapScript, Script combineScript, Script reduceScript,
Map<String, Object> params) {
public Factory(String name) {
super(name, InternalScriptedMetric.TYPE);
}

/**
* Set the <tt>init</tt> script.
*/
public void initScript(Script initScript) {
this.initScript = initScript;
}

/**
* Set the <tt>map</tt> script.
*/
public void mapScript(Script mapScript) {
this.mapScript = mapScript;
}

/**
* Set the <tt>combine</tt> script.
*/
public void combineScript(Script combineScript) {
this.combineScript = combineScript;
}

/**
* Set the <tt>reduce</tt> script.
*/
public void reduceScript(Script reduceScript) {
this.reduceScript = reduceScript;
}

/**
* Set parameters that will be available in the <tt>init</tt>,
* <tt>map</tt> and <tt>combine</tt> phases.
*/
public void params(Map<String, Object> params) {
this.params = params;
}

Expand Down Expand Up @@ -188,6 +222,73 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
return clone;
}

@Override
protected XContentBuilder internalXContent(XContentBuilder builder, Params builderParams) throws IOException {
builder.startObject();
if (initScript != null) {
builder.field(ScriptedMetricParser.INIT_SCRIPT_FIELD.getPreferredName(), initScript);
}

if (mapScript != null) {
builder.field(ScriptedMetricParser.MAP_SCRIPT_FIELD.getPreferredName(), mapScript);
}

if (combineScript != null) {
builder.field(ScriptedMetricParser.COMBINE_SCRIPT_FIELD.getPreferredName(), combineScript);
}

if (reduceScript != null) {
builder.field(ScriptedMetricParser.REDUCE_SCRIPT_FIELD.getPreferredName(), reduceScript);
}
if (params != null) {
builder.field(ScriptedMetricParser.PARAMS_FIELD.getPreferredName());
builder.map(params);
}
builder.endObject();
return builder;
}

@Override
protected AggregatorFactory doReadFrom(String name, StreamInput in) throws IOException {
Factory factory = new Factory(name);
factory.initScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.mapScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.combineScript = in.readOptionalStreamable(Script.SUPPLIER);
factory.reduceScript = in.readOptionalStreamable(Script.SUPPLIER);
if (in.readBoolean()) {
factory.params = in.readMap();
}
return factory;
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeOptionalStreamable(initScript);
out.writeOptionalStreamable(mapScript);
out.writeOptionalStreamable(combineScript);
out.writeOptionalStreamable(reduceScript);
boolean hasParams = params != null;
out.writeBoolean(hasParams);
if (hasParams) {
out.writeMap(params);
}
}

@Override
protected int doHashCode() {
return Objects.hash(initScript, mapScript, combineScript, reduceScript, params);
}

@Override
protected boolean doEquals(Object obj) {
Factory other = (Factory) obj;
return Objects.equals(initScript, other.initScript)
&& Objects.equals(mapScript, other.mapScript)
&& Objects.equals(combineScript, other.combineScript)
&& Objects.equals(reduceScript, other.reduceScript)
&& Objects.equals(params, other.params);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,19 @@ public AggregatorFactory parse(String aggregationName, XContentParser parser, Se
if (mapScript == null) {
throw new SearchParseException(context, "map_script field is required in [" + aggregationName + "].", parser.getTokenLocation());
}
return new ScriptedMetricAggregator.Factory(aggregationName, initScript, mapScript, combineScript, reduceScript, params);

ScriptedMetricAggregator.Factory factory = new ScriptedMetricAggregator.Factory(aggregationName);
factory.initScript(initScript);
factory.mapScript(mapScript);
factory.combineScript(combineScript);
factory.reduceScript(reduceScript);
factory.params(params);
return factory;
}

// NORELEASE implement this method when refactoring this aggregation
@Override
public AggregatorFactory[] getFactoryPrototypes() {
return null;
return new AggregatorFactory[] { new ScriptedMetricAggregator.Factory(null) };
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptService.ScriptType;
import org.elasticsearch.search.aggregations.BaseAggregationTestCase;
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator;
import org.elasticsearch.search.aggregations.metrics.scripted.ScriptedMetricAggregator.Factory;

import java.util.HashMap;
import java.util.Map;

public class ScriptedMetricTests extends BaseAggregationTestCase<ScriptedMetricAggregator.Factory> {

@Override
protected Factory createTestAggregatorFactory() {
Factory factory = new Factory(randomAsciiOfLengthBetween(1, 20));
if (randomBoolean()) {
factory.initScript(randomScript("initScript"));
}
factory.mapScript(randomScript("mapScript"));
if (randomBoolean()) {
factory.combineScript(randomScript("combineScript"));
}
if (randomBoolean()) {
factory.reduceScript(randomScript("reduceScript"));
}
if (randomBoolean()) {
Map<String, Object> params = new HashMap<String, Object>();
params.put("foo", "bar");
factory.params(params);
}
return factory;
}

private Script randomScript(String script) {
if (randomBoolean()) {
return new Script(script);
} else {
return new Script(script, randomFrom(ScriptType.values()), randomFrom("my_lang", null), null);
}
}

}

0 comments on commit 510848f

Please sign in to comment.