Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scripting: Fix painless compiler loader to know about context classes #32385

Merged
merged 5 commits into from
Jul 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@ final class Compiler {
/**
* A secure class loader used to define Painless scripts.
*/
static final class Loader extends SecureClassLoader {
final class Loader extends SecureClassLoader {
private final AtomicInteger lambdaCounter = new AtomicInteger(0);
private final PainlessLookup painlessLookup;

/**
* @param parent The parent ClassLoader.
*/
Loader(ClassLoader parent, PainlessLookup painlessLookup) {
Loader(ClassLoader parent) {
super(parent);

this.painlessLookup = painlessLookup;
}

/**
Expand All @@ -90,6 +87,15 @@ static final class Loader extends SecureClassLoader {
*/
@Override
public Class<?> findClass(String name) throws ClassNotFoundException {
if (scriptClass.getName().equals(name)) {
return scriptClass;
}
if (factoryClass != null && factoryClass.getName().equals(name)) {
return factoryClass;
}
if (statefulFactoryClass != null && statefulFactoryClass.getName().equals(name)) {
return statefulFactoryClass;
}
Class<?> found = painlessLookup.getClassFromBinaryName(name);

return found != null ? found : super.findClass(name);
Expand Down Expand Up @@ -139,13 +145,23 @@ int newLambdaIdentifier() {
* {@link Compiler}'s specified {@link PainlessLookup}.
*/
public Loader createLoader(ClassLoader parent) {
return new Loader(parent, painlessLookup);
return new Loader(parent);
}

/**
* The class/interface the script is guaranteed to derive/implement.
* The class/interface the script will implement.
*/
private final Class<?> scriptClass;

/**
* The class/interface to create the {@code scriptClass} instance.
*/
private final Class<?> factoryClass;

/**
* An optional class/interface to create the {@code factoryClass} instance.
*/
private final Class<?> base;
private final Class<?> statefulFactoryClass;

/**
* The whitelist the script will use.
Expand All @@ -154,11 +170,15 @@ public Loader createLoader(ClassLoader parent) {

/**
* Standard constructor.
* @param base The class/interface the script is guaranteed to derive/implement.
* @param scriptClass The class/interface the script will implement.
* @param factoryClass An optional class/interface to create the {@code scriptClass} instance.
* @param statefulFactoryClass An optional class/interface to create the {@code factoryClass} instance.
* @param painlessLookup The whitelist the script will use.
*/
Compiler(Class<?> base, PainlessLookup painlessLookup) {
this.base = base;
Compiler(Class<?> scriptClass, Class<?> factoryClass, Class<?> statefulFactoryClass, PainlessLookup painlessLookup) {
this.scriptClass = scriptClass;
this.factoryClass = factoryClass;
this.statefulFactoryClass = statefulFactoryClass;
this.painlessLookup = painlessLookup;
}

Expand All @@ -177,7 +197,7 @@ Constructor<?> compile(Loader loader, MainMethodReserved reserved, String name,
" plugin if a script longer than this length is a requirement.");
}

ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base);
ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass);
SSource root = Walker.buildPainlessTree(scriptClassInfo, reserved, name, source, settings, painlessLookup,
null);
root.analyze(painlessLookup);
Expand Down Expand Up @@ -209,7 +229,7 @@ byte[] compile(String name, String source, CompilerSettings settings, Printer de
" plugin if a script longer than this length is a requirement.");
}

ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base);
ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass);
SSource root = Walker.buildPainlessTree(scriptClassInfo, new MainMethodReserved(), name, source, settings, painlessLookup,
debugStream);
root.analyze(painlessLookup);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import java.security.PrivilegedAction;

import static java.lang.invoke.MethodHandles.Lookup;
import static org.elasticsearch.painless.Compiler.Loader;
import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION;
import static org.elasticsearch.painless.WriterConstants.CTOR_METHOD_NAME;
import static org.elasticsearch.painless.WriterConstants.DELEGATE_BOOTSTRAP_HANDLE;
Expand Down Expand Up @@ -207,7 +206,7 @@ public static CallSite lambdaBootstrap(
MethodType delegateMethodType,
int isDelegateInterface)
throws LambdaConversionException {
Loader loader = (Loader)lookup.lookupClass().getClassLoader();
Compiler.Loader loader = (Compiler.Loader)lookup.lookupClass().getClassLoader();
String lambdaClassName = Type.getInternalName(lookup.lookupClass()) + "$$Lambda" + loader.newLambdaIdentifier();
Type lambdaClassType = Type.getObjectType(lambdaClassName);
Type delegateClassType = Type.getObjectType(delegateClassName.replace('.', '/'));
Expand Down Expand Up @@ -457,11 +456,11 @@ private static void endLambdaClass(ClassWriter cw) {
}

/**
* Defines the {@link Class} for the lambda class using the same {@link Loader}
* Defines the {@link Class} for the lambda class using the same {@link Compiler.Loader}
* that originally defined the class for the Painless script.
*/
private static Class<?> createLambdaClass(
Loader loader,
Compiler.Loader loader,
ClassWriter cw,
Type lambdaClassType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ public PainlessScriptEngine(Settings settings, Map<ScriptContext<?>, List<Whitel
for (Map.Entry<ScriptContext<?>, List<Whitelist>> entry : contexts.entrySet()) {
ScriptContext<?> context = entry.getKey();
if (context.instanceClazz.equals(SearchScript.class) || context.instanceClazz.equals(ExecutableScript.class)) {
contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class,
contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class, null, null,
PainlessLookupBuilder.buildFromWhitelists(entry.getValue())));
} else {
contextsToCompilers.put(context, new Compiler(context.instanceClazz,
contextsToCompilers.put(context, new Compiler(context.instanceClazz, context.factoryClazz, context.statefulFactoryClazz,
PainlessLookupBuilder.buildFromWhitelists(entry.getValue())));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Map<String, Object> getTestMap() {
}

public void testGets() {
Compiler compiler = new Compiler(Gets.class, painlessLookup);
Compiler compiler = new Compiler(Gets.class, null, null, painlessLookup);
Map<String, Object> map = new HashMap<>();
map.put("s", 1);

Expand All @@ -87,7 +87,7 @@ public abstract static class NoArgs {
public abstract Object execute();
}
public void testNoArgs() {
Compiler compiler = new Compiler(NoArgs.class, painlessLookup);
Compiler compiler = new Compiler(NoArgs.class, null, null, painlessLookup);
assertEquals(1, ((NoArgs)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
assertEquals("foo", ((NoArgs)scriptEngine.compile(compiler, null, "'foo'", emptyMap())).execute());

Expand All @@ -111,13 +111,13 @@ public abstract static class OneArg {
public abstract Object execute(Object arg);
}
public void testOneArg() {
Compiler compiler = new Compiler(OneArg.class, painlessLookup);
Compiler compiler = new Compiler(OneArg.class, null, null, painlessLookup);
Object rando = randomInt();
assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));
rando = randomAlphaOfLength(5);
assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));

Compiler noargs = new Compiler(NoArgs.class, painlessLookup);
Compiler noargs = new Compiler(NoArgs.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, () ->
scriptEngine.compile(noargs, null, "doc", emptyMap()));
assertEquals("Variable [doc] is not defined.", e.getMessage());
Expand All @@ -132,7 +132,7 @@ public abstract static class ArrayArg {
public abstract Object execute(String[] arg);
}
public void testArrayArg() {
Compiler compiler = new Compiler(ArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(ArrayArg.class, null, null, painlessLookup);
String rando = randomAlphaOfLength(5);
assertEquals(rando, ((ArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new String[] {rando, "foo"}));
}
Expand All @@ -142,7 +142,7 @@ public abstract static class PrimitiveArrayArg {
public abstract Object execute(int[] arg);
}
public void testPrimitiveArrayArg() {
Compiler compiler = new Compiler(PrimitiveArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(PrimitiveArrayArg.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((PrimitiveArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new int[] {rando, 10}));
}
Expand All @@ -152,7 +152,7 @@ public abstract static class DefArrayArg {
public abstract Object execute(Object[] arg);
}
public void testDefArrayArg() {
Compiler compiler = new Compiler(DefArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(DefArrayArg.class, null, null, painlessLookup);
Object rando = randomInt();
assertEquals(rando, ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new Object[] {rando, 10}));
rando = randomAlphaOfLength(5);
Expand All @@ -170,7 +170,7 @@ public abstract static class ManyArgs {
public abstract boolean needsD();
}
public void testManyArgs() {
Compiler compiler = new Compiler(ManyArgs.class, painlessLookup);
Compiler compiler = new Compiler(ManyArgs.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((ManyArgs)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
assertEquals(10, ((ManyArgs)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).execute(1, 2, 3, 4));
Expand Down Expand Up @@ -198,7 +198,7 @@ public abstract static class VarargTest {
public abstract Object execute(String... arg);
}
public void testVararg() {
Compiler compiler = new Compiler(VarargTest.class, painlessLookup);
Compiler compiler = new Compiler(VarargTest.class, null, null, painlessLookup);
assertEquals("foo bar baz", ((VarargTest)scriptEngine.compile(compiler, null, "String.join(' ', Arrays.asList(arg))", emptyMap()))
.execute("foo", "bar", "baz"));
}
Expand All @@ -214,7 +214,7 @@ public Object executeWithASingleOne(int a, int b, int c) {
}
}
public void testDefaultMethods() {
Compiler compiler = new Compiler(DefaultMethods.class, painlessLookup);
Compiler compiler = new Compiler(DefaultMethods.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).executeWithASingleOne(rando, 0, 0));
Expand All @@ -228,7 +228,7 @@ public abstract static class ReturnsVoid {
public abstract void execute(Map<String, Object> map);
}
public void testReturnsVoid() {
Compiler compiler = new Compiler(ReturnsVoid.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsVoid.class, null, null, painlessLookup);
Map<String, Object> map = new HashMap<>();
((ReturnsVoid)scriptEngine.compile(compiler, null, "map.a = 'foo'", emptyMap())).execute(map);
assertEquals(singletonMap("a", "foo"), map);
Expand All @@ -247,7 +247,7 @@ public abstract static class ReturnsPrimitiveBoolean {
public abstract boolean execute();
}
public void testReturnsPrimitiveBoolean() {
Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, null, null, painlessLookup);

assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "true", emptyMap())).execute());
assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "false", emptyMap())).execute());
Expand Down Expand Up @@ -289,7 +289,7 @@ public abstract static class ReturnsPrimitiveInt {
public abstract int execute();
}
public void testReturnsPrimitiveInt() {
Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, null, null, painlessLookup);

assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1L", emptyMap())).execute());
Expand Down Expand Up @@ -331,7 +331,7 @@ public abstract static class ReturnsPrimitiveFloat {
public abstract float execute();
}
public void testReturnsPrimitiveFloat() {
Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, null, null, painlessLookup);

assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute(), 0);
assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "(float) 1.1d", emptyMap())).execute(), 0);
Expand Down Expand Up @@ -362,7 +362,7 @@ public abstract static class ReturnsPrimitiveDouble {
public abstract double execute();
}
public void testReturnsPrimitiveDouble() {
Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, null, null, painlessLookup);

assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1", emptyMap())).execute(), 0);
assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute(), 0);
Expand Down Expand Up @@ -396,7 +396,7 @@ public abstract static class NoArgumentsConstant {
public abstract Object execute(String foo);
}
public void testNoArgumentsConstant() {
Compiler compiler = new Compiler(NoArgumentsConstant.class, painlessLookup);
Compiler compiler = new Compiler(NoArgumentsConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith(
Expand All @@ -409,7 +409,7 @@ public abstract static class WrongArgumentsConstant {
public abstract Object execute(String foo);
}
public void testWrongArgumentsConstant() {
Compiler compiler = new Compiler(WrongArgumentsConstant.class, painlessLookup);
Compiler compiler = new Compiler(WrongArgumentsConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith(
Expand All @@ -422,7 +422,7 @@ public abstract static class WrongLengthOfArgumentConstant {
public abstract Object execute(String foo);
}
public void testWrongLengthOfArgumentConstant() {
Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, painlessLookup);
Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith("[" + WrongLengthOfArgumentConstant.class.getName() + "#ARGUMENTS] has length [2] but ["
Expand All @@ -434,7 +434,7 @@ public abstract static class UnknownArgType {
public abstract Object execute(UnknownArgType foo);
}
public void testUnknownArgType() {
Compiler compiler = new Compiler(UnknownArgType.class, painlessLookup);
Compiler compiler = new Compiler(UnknownArgType.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("[foo] is of unknown type [" + UnknownArgType.class.getName() + ". Painless interfaces can only accept arguments "
Expand All @@ -446,7 +446,7 @@ public abstract static class UnknownReturnType {
public abstract UnknownReturnType execute(String foo);
}
public void testUnknownReturnType() {
Compiler compiler = new Compiler(UnknownReturnType.class, painlessLookup);
Compiler compiler = new Compiler(UnknownReturnType.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("Painless can only implement execute methods returning a whitelisted type but [" + UnknownReturnType.class.getName()
Expand All @@ -458,7 +458,7 @@ public abstract static class UnknownArgTypeInArray {
public abstract Object execute(UnknownArgTypeInArray[] foo);
}
public void testUnknownArgTypeInArray() {
Compiler compiler = new Compiler(UnknownArgTypeInArray.class, painlessLookup);
Compiler compiler = new Compiler(UnknownArgTypeInArray.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("[foo] is of unknown type [" + UnknownArgTypeInArray.class.getName() + ". Painless interfaces can only accept "
Expand All @@ -470,7 +470,7 @@ public abstract static class TwoExecuteMethods {
public abstract Object execute(boolean foo);
}
public void testTwoExecuteMethods() {
Compiler compiler = new Compiler(TwoExecuteMethods.class, painlessLookup);
Compiler compiler = new Compiler(TwoExecuteMethods.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "null", emptyMap()));
assertEquals("Painless can only implement interfaces that have a single method named [execute] but ["
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static String toString(Class<?> iface, String source, CompilerSettings settings)
PrintWriter outputWriter = new PrintWriter(output);
Textifier textifier = new Textifier();
try {
new Compiler(iface, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS))
new Compiler(iface, null, null, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS))
.compile("<debugging>", source, settings, textifier);
} catch (RuntimeException e) {
textifier.print(outputWriter);
Expand Down