Skip to content

Commit

Permalink
Add implicit this for class binding in Painless (#40285)
Browse files Browse the repository at this point in the history
This change allows class bindings to add as their first argument, the base script 
class. The this reference to the base script class will be implicitly passed into a 
class binding as the first constructor argument upon initialization when 
specified as the first argument in whitelist entry for the class binding. This 
allows a class binding access to additional information added to the base script 
class such as more information about the current document or current shard. 
One extra requirement for this to work is the appropriate script base class 
must be whitelisted (should be empty).
  • Loading branch information
jdconrad authored Mar 22, 2019
1 parent 2c825fd commit 8da56dc
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 51 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public static Locals newLocalScope(Locals currentScope) {
*/
public static Locals newLambdaScope(Locals programScope, String name, Class<?> returnType, List<Parameter> parameters,
int captureCount, int maxLoopCounter) {
Locals locals = new Locals(programScope, programScope.painlessLookup, returnType, KEYWORDS);
Locals locals = new Locals(programScope, programScope.painlessLookup, programScope.baseClass, returnType, KEYWORDS);
locals.methods = programScope.methods;
List<Class<?>> typeParameters = parameters.stream().map(parameter -> typeToJavaType(parameter.clazz)).collect(Collectors.toList());
locals.methods.put(buildLocalMethodKey(name, parameters.size()), new LocalMethod(name, returnType, typeParameters,
Expand All @@ -113,7 +113,7 @@ public static Locals newLambdaScope(Locals programScope, String name, Class<?> r

/** Creates a new function scope inside the current scope */
public static Locals newFunctionScope(Locals programScope, Class<?> returnType, List<Parameter> parameters, int maxLoopCounter) {
Locals locals = new Locals(programScope, programScope.painlessLookup, returnType, KEYWORDS);
Locals locals = new Locals(programScope, programScope.painlessLookup, programScope.baseClass, returnType, KEYWORDS);
locals.methods = programScope.methods;
for (Parameter parameter : parameters) {
locals.addVariable(parameter.location, parameter.clazz, parameter.name, false);
Expand All @@ -127,8 +127,8 @@ public static Locals newFunctionScope(Locals programScope, Class<?> returnType,

/** Creates a new main method scope */
public static Locals newMainMethodScope(ScriptClassInfo scriptClassInfo, Locals programScope, int maxLoopCounter) {
Locals locals = new Locals(
programScope, programScope.painlessLookup, scriptClassInfo.getExecuteMethodReturnType(), KEYWORDS);
Locals locals = new Locals(programScope, programScope.painlessLookup,
scriptClassInfo.getBaseClass(), scriptClassInfo.getExecuteMethodReturnType(), KEYWORDS);
locals.methods = programScope.methods;
// This reference. Internal use only.
locals.defineVariable(null, Object.class, THIS, true);
Expand All @@ -146,8 +146,8 @@ public static Locals newMainMethodScope(ScriptClassInfo scriptClassInfo, Locals
}

/** Creates a new program scope: the list of methods. It is the parent for all methods */
public static Locals newProgramScope(PainlessLookup painlessLookup, Collection<LocalMethod> methods) {
Locals locals = new Locals(null, painlessLookup, null, null);
public static Locals newProgramScope(ScriptClassInfo scriptClassInfo, PainlessLookup painlessLookup, Collection<LocalMethod> methods) {
Locals locals = new Locals(null, painlessLookup, scriptClassInfo.getBaseClass(), null, null);
locals.methods = new HashMap<>();
for (LocalMethod method : methods) {
locals.addMethod(method);
Expand Down Expand Up @@ -214,10 +214,17 @@ public PainlessLookup getPainlessLookup() {
return painlessLookup;
}

/** Base class for the compiled script. */
public Class<?> getBaseClass() {
return baseClass;
}

///// private impl

/** Whitelist against which this script is being compiled. */
private final PainlessLookup painlessLookup;
/** Base class for the compiled script. */
private final Class<?> baseClass;
// parent scope
private final Locals parent;
// return type of this scope
Expand All @@ -235,15 +242,16 @@ public PainlessLookup getPainlessLookup() {
* Create a new Locals
*/
private Locals(Locals parent) {
this(parent, parent.painlessLookup, parent.returnType, parent.keywords);
this(parent, parent.painlessLookup, parent.baseClass, parent.returnType, parent.keywords);
}

/**
* Create a new Locals with specified return type
*/
private Locals(Locals parent, PainlessLookup painlessLookup, Class<?> returnType, Set<String> keywords) {
private Locals(Locals parent, PainlessLookup painlessLookup, Class<?> baseClass, Class<?> returnType, Set<String> keywords) {
this.parent = parent;
this.painlessLookup = painlessLookup;
this.baseClass = baseClass;
this.returnType = returnType;
this.keywords = keywords;
if (parent == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public final class ECallLocal extends AExpression {
private LocalMethod localMethod = null;
private PainlessMethod importedMethod = null;
private PainlessClassBinding classBinding = null;
private int classBindingOffset = 0;
private PainlessInstanceBinding instanceBinding = null;

public ECallLocal(Location location, String name, List<AExpression> arguments) {
Expand All @@ -75,12 +76,37 @@ void analyze(Locals locals) {
if (importedMethod == null) {
classBinding = locals.getPainlessLookup().lookupPainlessClassBinding(name, arguments.size());

// check to see if this class binding requires an implicit this reference
if (classBinding != null && classBinding.typeParameters.isEmpty() == false &&
classBinding.typeParameters.get(0) == locals.getBaseClass()) {
classBinding = null;
}

if (classBinding == null) {
instanceBinding = locals.getPainlessLookup().lookupPainlessInstanceBinding(name, arguments.size());
// This extra check looks for a possible match where the class binding requires an implicit this
// reference. This is a temporary solution to allow the class binding access to data from the
// base script class without need for a user to add additional arguments. A long term solution
// will likely involve adding a class instance binding where any instance can have a class binding
// as part of its API. However, the situation at run-time is difficult and will modifications that
// are a substantial change if even possible to do.
classBinding = locals.getPainlessLookup().lookupPainlessClassBinding(name, arguments.size() + 1);

if (classBinding != null) {
if (classBinding.typeParameters.isEmpty() == false &&
classBinding.typeParameters.get(0) == locals.getBaseClass()) {
classBindingOffset = 1;
} else {
classBinding = null;
}
}

if (instanceBinding == null) {
throw createError(
new IllegalArgumentException("Unknown call [" + name + "] with [" + arguments.size() + "] arguments."));
if (classBinding == null) {
instanceBinding = locals.getPainlessLookup().lookupPainlessInstanceBinding(name, arguments.size());

if (instanceBinding == null) {
throw createError(new IllegalArgumentException(
"Unknown call [" + name + "] with [" + arguments.size() + "] arguments."));
}
}
}
}
Expand All @@ -104,10 +130,13 @@ void analyze(Locals locals) {
throw new IllegalStateException("Illegal tree structure.");
}

// if the class binding is using an implicit this reference then the arguments counted must
// be incremented by 1 as the this reference will not be part of the arguments passed into
// the class binding call
for (int argument = 0; argument < arguments.size(); ++argument) {
AExpression expression = arguments.get(argument);

expression.expected = typeParameters.get(argument);
expression.expected = typeParameters.get(argument + classBindingOffset);
expression.internal = true;
expression.analyze(locals);
arguments.set(argument, expression.cast(locals));
Expand Down Expand Up @@ -136,7 +165,7 @@ void write(MethodWriter writer, Globals globals) {
} else if (classBinding != null) {
String name = globals.addClassBinding(classBinding.javaConstructor.getDeclaringClass());
Type type = Type.getType(classBinding.javaConstructor.getDeclaringClass());
int javaConstructorParameterCount = classBinding.javaConstructor.getParameterCount();
int javaConstructorParameterCount = classBinding.javaConstructor.getParameterCount() - classBindingOffset;

Label nonNull = new Label();

Expand All @@ -147,6 +176,10 @@ void write(MethodWriter writer, Globals globals) {
writer.newInstance(type);
writer.dup();

if (classBindingOffset == 1) {
writer.loadThis();
}

for (int argument = 0; argument < javaConstructorParameterCount; ++argument) {
arguments.get(argument).write(writer, globals);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ public void analyze(PainlessLookup painlessLookup) {
}
}

Locals locals = Locals.newProgramScope(painlessLookup, methods.values());
Locals locals = Locals.newProgramScope(scriptClassInfo, painlessLookup, methods.values());
analyze(locals);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,5 +266,4 @@ class org.elasticsearch.painless.FeatureTest no_import {
static_import {
int staticAddIntsTest(int, int) from_class org.elasticsearch.painless.StaticTest
float staticAddFloatsTest(float, float) from_class org.elasticsearch.painless.FeatureTest
int testAddWithState(int, int, int, double) bound_to org.elasticsearch.painless.BindingTest
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.painless.spi.Whitelist;
import org.elasticsearch.painless.spi.WhitelistInstanceBinding;
import org.elasticsearch.painless.spi.WhitelistLoader;
import org.elasticsearch.script.ScriptContext;

import java.util.ArrayList;
Expand All @@ -30,6 +31,44 @@

public class BindingsTests extends ScriptTestCase {

public static class BindingTestClass {
public int state;

public BindingTestClass(int state0, int state1) {
this.state = state0 + state1;
}

public int addWithState(int istateless, double dstateless) {
return istateless + state + (int)dstateless;
}
}

public static class ThisBindingTestClass {
private BindingsTestScript bindingsTestScript;
private int state;

public ThisBindingTestClass(BindingsTestScript bindingsTestScript, int state0, int state1) {
this.bindingsTestScript = bindingsTestScript;
this.state = state0 + state1;
}

public int addThisWithState(int istateless, double dstateless) {
return istateless + state + (int)dstateless + bindingsTestScript.getTestValue();
}
}

public static class EmptyThisBindingTestClass {
private BindingsTestScript bindingsTestScript;

public EmptyThisBindingTestClass(BindingsTestScript bindingsTestScript) {
this.bindingsTestScript = bindingsTestScript;
}

public int addEmptyThisWithState(int istateless) {
return istateless + bindingsTestScript.getTestValue();
}
}

public static class InstanceBindingTestClass {
private int value;

Expand All @@ -48,6 +87,7 @@ public int getInstanceBindingValue() {

public abstract static class BindingsTestScript {
public static final String[] PARAMETERS = { "test", "bound" };
public int getTestValue() {return 7;}
public abstract int execute(int test, int bound);
public interface Factory {
BindingsTestScript newInstance();
Expand All @@ -59,6 +99,7 @@ public interface Factory {
protected Map<ScriptContext<?>, List<Whitelist>> scriptContexts() {
Map<ScriptContext<?>, List<Whitelist>> contexts = super.scriptContexts();
List<Whitelist> whitelists = new ArrayList<>(Whitelist.BASE_WHITELISTS);
whitelists.add(WhitelistLoader.loadFromResourceFiles(Whitelist.class, "org.elasticsearch.painless.test"));

InstanceBindingTestClass instanceBindingTestClass = new InstanceBindingTestClass(1);
WhitelistInstanceBinding getter = new WhitelistInstanceBinding("test", instanceBindingTestClass,
Expand All @@ -77,11 +118,15 @@ protected Map<ScriptContext<?>, List<Whitelist>> scriptContexts() {
}

public void testBasicClassBinding() {
assertEquals(15, exec("testAddWithState(4, 5, 6, 0.0)"));
String script = "addWithState(4, 5, 6, 0.0)";
BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
BindingsTestScript executableScript = factory.newInstance();

assertEquals(15, executableScript.execute(0, 0));
}

public void testRepeatedClassBinding() {
String script = "testAddWithState(4, 5, test, 0.0)";
String script = "addWithState(4, 5, test, 0.0)";
BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
BindingsTestScript executableScript = factory.newInstance();

Expand All @@ -91,14 +136,34 @@ public void testRepeatedClassBinding() {
}

public void testBoundClassBinding() {
String script = "testAddWithState(4, bound, test, 0.0)";
String script = "addWithState(4, bound, test, 0.0)";
BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
BindingsTestScript executableScript = factory.newInstance();

assertEquals(10, executableScript.execute(5, 1));
assertEquals(9, executableScript.execute(4, 2));
}

public void testThisClassBinding() {
String script = "addThisWithState(4, bound, test, 0.0)";

BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
BindingsTestScript executableScript = factory.newInstance();

assertEquals(17, executableScript.execute(5, 1));
assertEquals(16, executableScript.execute(4, 2));
}

public void testEmptyThisClassBinding() {
String script = "addEmptyThisWithState(test)";

BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
BindingsTestScript executableScript = factory.newInstance();

assertEquals(8, executableScript.execute(1, 0));
assertEquals(9, executableScript.execute(2, 0));
}

public void testInstanceBinding() {
String script = "getInstanceBindingValue() + test + bound";
BindingsTestScript.Factory factory = scriptEngine.compile(null, script, BindingsTestScript.CONTEXT, Collections.emptyMap());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# whitelist for tests
class org.elasticsearch.painless.BindingsTests$BindingsTestScript {
}

static_import {
int addWithState(int, int, int, double) bound_to org.elasticsearch.painless.BindingsTests$BindingTestClass
int addThisWithState(BindingsTests.BindingsTestScript, int, int, int, double) bound_to org.elasticsearch.painless.BindingsTests$ThisBindingTestClass
int addEmptyThisWithState(BindingsTests.BindingsTestScript, int) bound_to org.elasticsearch.painless.BindingsTests$EmptyThisBindingTestClass
}

0 comments on commit 8da56dc

Please sign in to comment.