diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java index e08835f0bea14..01084ee0b7f8b 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java @@ -89,7 +89,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance( mergeParams(aggParams, combineScriptParams), aggState); - final Script reduceScript = deepCopyScript(this.reduceScript, context); + final Script reduceScript = deepCopyScript(this.reduceScript, context, aggParams); if (initScript != null) { initScript.execute(); CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script"); @@ -99,12 +99,9 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu pipelineAggregators, metaData); } - private static Script deepCopyScript(Script script, SearchContext context) { + private static Script deepCopyScript(Script script, SearchContext context, Map aggParams) { if (script != null) { - Map params = script.getParams(); - if (params != null) { - params = deepCopyParams(params, context); - } + Map params = mergeParams(aggParams, deepCopyParams(script.getParams(), context)); return new Script(script.getType(), script.getLang(), script.getIdOrCode(), params); } else { return null; diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java index 05115a03e300f..5f74937f6610b 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java @@ -71,7 +71,9 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams", Collections.singletonMap("itemValue", 12)); private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams", - Collections.singletonMap("divisor", 4)); + Collections.singletonMap("multiplier", 4)); + private static final Script REDUCE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "reduceScriptParams", + Collections.singletonMap("additional", 2)); private static final String CONFLICTING_PARAM_NAME = "initialValue"; private static final Script INIT_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptSelfRef", @@ -140,9 +142,14 @@ public static void initMockScripts() { }); SCRIPTS.put("combineScriptParams", params -> { Map state = (Map) params.get("state"); - int divisor = ((Integer) params.get("divisor")); - return ((List) state.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum(); + int multiplier = ((Integer) params.get("multiplier")); + return ((List) state.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i * multiplier).sum(); }); + SCRIPTS.put("reduceScriptParams", params -> + ((List)params.get("states")).stream().mapToInt(i -> (int)i).sum() + + (int)params.get("aggs_param") + (int)params.get("additional") - + ((List)params.get("states")).size()*24*4 + ); SCRIPTS.put("initScriptSelfRef", params -> { Map state = (Map) params.get("state"); @@ -279,7 +286,33 @@ public void testScriptParamsPassedThrough() throws IOException { ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); // The result value depends on the script params. - assertEquals(306, scriptedMetric.aggregation()); + assertEquals(4896, scriptedMetric.aggregation()); + } + } + } + + public void testAggParamsPassedToReduceScript() throws IOException { + MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap()); + Map engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine); + ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS); + + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < 100; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + aggregationBuilder.params(Collections.singletonMap("aggs_param", 1)) + .initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS) + .combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS); + ScriptedMetric scriptedMetric = searchAndReduce( + newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService); + + // The result value depends on the script params. + assertEquals(4803, scriptedMetric.aggregation()); } } }