Skip to content

Commit

Permalink
Simpler regex constants in painless (elastic#68486)
Browse files Browse the repository at this point in the history
Replaces the double `Pattern.compile` invocations in painless scripts
with the fancy constant injection we added in elastic#68088. This caused one of
the tests to fail. It turns out that we weren't fully iterating the IR
tree during the constant folding phases. I started experimenting and
added a ton of tests that failed. Then I fixed them by changing the IR
tree walking code.
  • Loading branch information
nik9000 authored Feb 3, 2021
1 parent 92b59d9 commit e686e18
Show file tree
Hide file tree
Showing 14 changed files with 80 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public <Scope> void visit(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {

@Override
public <Scope> void visitChildren(IRTreeVisitor<Scope> irTreeVisitor, Scope scope) {
// do nothing; terminal node
getChildNode().visit(irTreeVisitor, scope);
}

/* ---- end visitor ---- */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,10 @@ public void visitConstant(ConstantNode irConstantNode, WriteScope writeScope) {
*/
String fieldName = irConstantNode.getDecorationValue(IRDConstantFieldName.class);
Type asmFieldType = MethodWriter.getType(irConstantNode.getDecorationValue(IRDExpressionType.class));
if (asmFieldType == null) {
throw irConstantNode.getLocation()
.createError(new IllegalStateException("Didn't attach constant to [" + irConstantNode + "]"));
}
methodWriter.getStatic(CLASS_TYPE, fieldName, asmFieldType);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2051,51 +2051,52 @@ public void visitRegex(ERegex userRegexNode, SemanticScope semanticScope) {

Location location = userRegexNode.getLocation();

int constant = 0;
int regexFlags = 0;

for (int i = 0; i < flags.length(); ++i) {
char flag = flags.charAt(i);

switch (flag) {
case 'c':
constant |= Pattern.CANON_EQ;
regexFlags |= Pattern.CANON_EQ;
break;
case 'i':
constant |= Pattern.CASE_INSENSITIVE;
regexFlags |= Pattern.CASE_INSENSITIVE;
break;
case 'l':
constant |= Pattern.LITERAL;
regexFlags |= Pattern.LITERAL;
break;
case 'm':
constant |= Pattern.MULTILINE;
regexFlags |= Pattern.MULTILINE;
break;
case 's':
constant |= Pattern.DOTALL;
regexFlags |= Pattern.DOTALL;
break;
case 'U':
constant |= Pattern.UNICODE_CHARACTER_CLASS;
regexFlags |= Pattern.UNICODE_CHARACTER_CLASS;
break;
case 'u':
constant |= Pattern.UNICODE_CASE;
regexFlags |= Pattern.UNICODE_CASE;
break;
case 'x':
constant |= Pattern.COMMENTS;
regexFlags |= Pattern.COMMENTS;
break;
default:
throw new IllegalArgumentException("invalid regular expression: unknown flag [" + flag + "]");
}
}

Pattern compiled;
try {
Pattern.compile(pattern, constant);
compiled = Pattern.compile(pattern, regexFlags);
} catch (PatternSyntaxException pse) {
throw new Location(location.getSourceName(), location.getOffset() + 1 + pse.getIndex()).createError(
new IllegalArgumentException("invalid regular expression: " +
"could not compile regex constant [" + pattern + "] with flags [" + flags + "]", pse));
}

semanticScope.putDecoration(userRegexNode, new ValueType(Pattern.class));
semanticScope.putDecoration(userRegexNode, new StandardConstant(constant));
semanticScope.putDecoration(userRegexNode, new StandardConstant(compiled));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
import org.elasticsearch.painless.ir.StoreDotDefNode;
import org.elasticsearch.painless.ir.StoreDotNode;
import org.elasticsearch.painless.ir.StoreDotShortcutNode;
import org.elasticsearch.painless.ir.StoreFieldMemberNode;
import org.elasticsearch.painless.ir.StoreListShortcutNode;
import org.elasticsearch.painless.ir.StoreMapShortcutNode;
import org.elasticsearch.painless.ir.StoreVariableNode;
Expand Down Expand Up @@ -206,8 +205,8 @@
import org.elasticsearch.painless.symbol.IRDecorations.IRDConstant;
import org.elasticsearch.painless.symbol.IRDecorations.IRDConstructor;
import org.elasticsearch.painless.symbol.IRDecorations.IRDDeclarationType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDDepth;
import org.elasticsearch.painless.symbol.IRDecorations.IRDDefReferenceEncoding;
import org.elasticsearch.painless.symbol.IRDecorations.IRDDepth;
import org.elasticsearch.painless.symbol.IRDecorations.IRDExceptionType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDExpressionType;
import org.elasticsearch.painless.symbol.IRDecorations.IRDField;
Expand Down Expand Up @@ -1321,77 +1320,10 @@ public void visitNull(ENull userNullNode, ScriptScope scriptScope) {

@Override
public void visitRegex(ERegex userRegexNode, ScriptScope scriptScope) {
String memberFieldName = scriptScope.getNextSyntheticName("regex");

FieldNode irFieldNode = new FieldNode(userRegexNode.getLocation());
irFieldNode.attachDecoration(new IRDModifiers(Modifier.FINAL | Modifier.STATIC | Modifier.PRIVATE));
irFieldNode.attachDecoration(new IRDFieldType(Pattern.class));
irFieldNode.attachDecoration(new IRDName(memberFieldName));

irClassNode.addFieldNode(irFieldNode);

try {
StatementExpressionNode irStatementExpressionNode = new StatementExpressionNode(userRegexNode.getLocation());

BlockNode blockNode = irClassNode.getClinitBlockNode();
blockNode.addStatementNode(irStatementExpressionNode);

StoreFieldMemberNode irStoreFieldMemberNode = new StoreFieldMemberNode(userRegexNode.getLocation());
irStoreFieldMemberNode.attachDecoration(new IRDExpressionType(void.class));
irStoreFieldMemberNode.attachDecoration(new IRDStoreType(Pattern.class));
irStoreFieldMemberNode.attachDecoration(new IRDName(memberFieldName));
irStoreFieldMemberNode.attachCondition(IRCStatic.class);

irStatementExpressionNode.setExpressionNode(irStoreFieldMemberNode);

BinaryImplNode irBinaryImplNode = new BinaryImplNode(userRegexNode.getLocation());
irBinaryImplNode.attachDecoration(new IRDExpressionType(Pattern.class));

irStoreFieldMemberNode.setChildNode(irBinaryImplNode);

StaticNode irStaticNode = new StaticNode(userRegexNode.getLocation());
irStaticNode.attachDecoration(new IRDExpressionType(Pattern.class));

irBinaryImplNode.setLeftNode(irStaticNode);

InvokeCallNode invokeCallNode = new InvokeCallNode(userRegexNode.getLocation());
invokeCallNode.attachDecoration(new IRDExpressionType(Pattern.class));
invokeCallNode.setBox(Pattern.class);
invokeCallNode.setMethod(new PainlessMethod(
Pattern.class.getMethod("compile", String.class, int.class),
Pattern.class,
Pattern.class,
Arrays.asList(String.class, int.class),
null,
null,
null
)
);

irBinaryImplNode.setRightNode(invokeCallNode);

ConstantNode irConstantNode = new ConstantNode(userRegexNode.getLocation());
irConstantNode.attachDecoration(new IRDExpressionType(String.class));
irConstantNode.attachDecoration(new IRDConstant(userRegexNode.getPattern()));

invokeCallNode.addArgumentNode(irConstantNode);

irConstantNode = new ConstantNode(userRegexNode.getLocation());
irConstantNode.attachDecoration(new IRDExpressionType(int.class));
irConstantNode.attachDecoration(
new IRDConstant(scriptScope.getDecoration(userRegexNode, StandardConstant.class).getStandardConstant()));

invokeCallNode.addArgumentNode(irConstantNode);
} catch (Exception exception) {
throw userRegexNode.createError(new IllegalStateException("illegal tree structure"));
}

LoadFieldMemberNode irLoadFieldMemberNode = new LoadFieldMemberNode(userRegexNode.getLocation());
irLoadFieldMemberNode.attachDecoration(new IRDExpressionType(Pattern.class));
irLoadFieldMemberNode.attachDecoration(new IRDName(memberFieldName));
irLoadFieldMemberNode.attachCondition(IRCStatic.class);

scriptScope.putDecoration(userRegexNode, new IRNodeDecoration(irLoadFieldMemberNode));
ConstantNode constant = new ConstantNode(userRegexNode.getLocation());
constant.attachDecoration(new IRDExpressionType(Pattern.class));
constant.attachDecoration(new IRDConstant(scriptScope.getDecoration(userRegexNode, StandardConstant.class).getStandardConstant()));
scriptScope.putDecoration(userRegexNode, new IRNodeDecoration(constant));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,28 @@ public void testCast() {
assertBytecodeExists("2+'2D'", "LDC \"22D\"");
assertBytecodeExists("4L<5F", "ICONST_1");
}

public void testStoreInMap() {
assertBytecodeExists("Map m = [:]; m.a = 1 + 1; m.a", "ICONST_2");
}

public void testStoreInMapDef() {
assertBytecodeExists("def m = [:]; m.a = 1 + 1; m.a", "ICONST_2");
}

public void testStoreInList() {
assertBytecodeExists("List l = [null]; l.0 = 1 + 1; l.0", "ICONST_2");
}

public void testStoreInListDef() {
assertBytecodeExists("def l = [null]; l.0 = 1 + 1; l.0", "ICONST_2");
}

public void testStoreInArray() {
assertBytecodeExists("int[] a = new int[1]; a[0] = 1 + 1; a[0]", "ICONST_2");
}

public void testStoreInArrayDef() {
assertBytecodeExists("def a = new int[1]; a[0] = 1 + 1; a[0]", "ICONST_2");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ public void testPatternAfterUnaryNotBoolean() {

public void testInTernaryCondition() {
assertEquals(true, exec("return /foo/.matcher('foo').matches() ? true : false"));
assertEquals(1, exec("def i = 0; i += /foo/.matcher('foo').matches() ? 1 : 1; return i"));
assertEquals(1, exec("def i = 0; i += /foo/.matcher('foo').matches() ? 1 : 0; return i"));
assertEquals(true, exec("return 'foo' ==~ /foo/ ? true : false"));
assertEquals(1, exec("def i = 0; i += 'foo' ==~ /foo/ ? 1 : 1; return i"));
assertEquals(1, exec("def i = 0; i += 'foo' ==~ /foo/ ? 1 : 0; return i"));
}

public void testInTernaryTrueArm() {
Expand Down Expand Up @@ -232,6 +232,30 @@ public void testReplaceFirstQuoteReplacement() {
exec("'the quick brown fox'.replaceFirst(/[aeiou]/, m -> '$' + m.group().toUpperCase(Locale.ROOT))"));
}

public void testStoreInMap() {
assertEquals(true, exec("Map m = [:]; m.a = /foo/; m.a.matcher('foo').matches()"));
}

public void testStoreInMapDef() {
assertEquals(true, exec("def m = [:]; m.a = /foo/; m.a.matcher('foo').matches()"));
}

public void testStoreInList() {
assertEquals(true, exec("List l = [null]; l.0 = /foo/; l.0.matcher('foo').matches()"));
}

public void testStoreInListDef() {
assertEquals(true, exec("def l = [null]; l.0 = /foo/; l.0.matcher('foo').matches()"));
}

public void testStoreInArray() {
assertEquals(true, exec("Pattern[] a = new Pattern[1]; a[0] = /foo/; a[0].matcher('foo').matches()"));
}

public void testStoreInArrayDef() {
assertEquals(true, exec("def a = new Pattern[1]; a[0] = /foo/; a[0].matcher('foo').matches()"));
}

public void testCantUsePatternCompile() {
IllegalArgumentException e = expectScriptThrows(IllegalArgumentException.class, () -> {
exec("Pattern.compile('aa')");
Expand Down

0 comments on commit e686e18

Please sign in to comment.