Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase ATN states size limit, simplify ATN serialization #3546

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,10 @@ public static RuntimeTestDescriptor[] getRuntimeTestDescriptors(String group, St
}

if (group.equals("LexerExec")) {
descriptors.add(GeneratedLexerDescriptors.getLineSeparatorLfTest(targetName));
descriptors.add(GeneratedLexerDescriptors.getLineSeparatorCrLfTest(targetName));
descriptors.add(GeneratedLexerDescriptors.getLineSeparatorLfDescriptor(targetName));
descriptors.add(GeneratedLexerDescriptors.getLineSeparatorCrLfDescriptor(targetName));
descriptors.add(GeneratedLexerDescriptors.getLargeLexerDescriptor(targetName));
descriptors.add(GeneratedLexerDescriptors.getAtnStatesSizeMoreThan65535Descriptor(targetName));
}

return descriptors.toArray(new RuntimeTestDescriptor[0]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package org.antlr.v4.test.runtime;

import java.util.Collections;

public class GeneratedLexerDescriptors {
static RuntimeTestDescriptor getLineSeparatorLfTest(String targetName) {
static RuntimeTestDescriptor getLineSeparatorLfDescriptor(String targetName) {
UniversalRuntimeTestDescriptor result = new UniversalRuntimeTestDescriptor();
result.name = "LineSeparatorLf";
result.targetName = targetName;
Expand All @@ -20,7 +22,7 @@ static RuntimeTestDescriptor getLineSeparatorLfTest(String targetName) {
return result;
}

static RuntimeTestDescriptor getLineSeparatorCrLfTest(String targetName) {
static RuntimeTestDescriptor getLineSeparatorCrLfDescriptor(String targetName) {
UniversalRuntimeTestDescriptor result = new UniversalRuntimeTestDescriptor();
result.name = "LineSeparatorCrLf";
result.targetName = targetName;
Expand Down Expand Up @@ -65,4 +67,50 @@ static RuntimeTestDescriptor getLargeLexerDescriptor(String targetName) {
"[@1,5:4='<EOF>',<-1>,1:5]\n";
return result;
}

static RuntimeTestDescriptor getAtnStatesSizeMoreThan65535Descriptor(String targetName) {
UniversalRuntimeTestDescriptor result = new UniversalRuntimeTestDescriptor();
result.name = "AtnStatesSizeMoreThan65535";
result.notes = "Regression for https://github.com/antlr/antlr4/issues/1863";
result.targetName = targetName;
result.testType = "Lexer";

final int tokensCount = 1024;
final String suffix = String.join("", Collections.nCopies(70, "_"));

String grammarName = "L";
StringBuilder grammar = new StringBuilder();
grammar.append("lexer grammar ").append(grammarName).append(";\n");
grammar.append('\n');
StringBuilder input = new StringBuilder();
StringBuilder output = new StringBuilder();
int startOffset;
int stopOffset = -2;
for (int i = 0; i < tokensCount; i++) {
String value = "T_" + i + suffix;
grammar.append(value).append(": '").append(value).append("';\n");
input.append(value).append('\n');

startOffset = stopOffset + 2;
stopOffset += value.length() + 1;

output.append("[@").append(i).append(',').append(startOffset).append(':').append(stopOffset)
.append("='").append(value).append("',<").append(i + 1).append(">,").append(i + 1)
.append(":0]\n");
}

grammar.append("\n");
grammar.append("WS: [ \\t\\r\\n]+ -> skip;\n");

startOffset = stopOffset + 2;
stopOffset = startOffset - 1;
output.append("[@").append(tokensCount).append(',').append(startOffset).append(':').append(stopOffset)
.append("='<EOF>',<-1>,").append(tokensCount + 1).append(":0]\n");

result.grammar = grammar.toString();
result.grammarName = grammarName;
result.input = input.toString();
result.output = output.toString();
return result;
}
}
56 changes: 26 additions & 30 deletions runtime/CSharp/src/Atn/ATNDeserializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ public virtual ATN Deserialize(char[] data)
ReadRules (atn);
ReadModes (atn);
IList<IntervalSet> sets = new List<IntervalSet>();
ReadSets (atn, sets, ReadInt);
ReadSets (atn, sets, ReadInt32);
ReadSets (sets);
ReadEdges (atn, sets);
ReadDecisions (atn);
ReadLexerActions (atn);
Expand Down Expand Up @@ -188,18 +187,7 @@ protected internal virtual void ReadLexerActions(ATN atn)
atn.lexerActions = new ILexerAction[ReadInt()];
for (int i_10 = 0; i_10 < atn.lexerActions.Length; i_10++)
{
LexerActionType actionType = (LexerActionType)ReadInt();
int data1 = ReadInt();
if (data1 == unchecked((int)(0xFFFF)))
{
data1 = -1;
}
int data2 = ReadInt();
if (data2 == unchecked((int)(0xFFFF)))
{
data2 = -1;
}
ILexerAction lexerAction = LexerActionFactory(actionType, data1, data2);
ILexerAction lexerAction = LexerActionFactory((LexerActionType)ReadInt(), ReadInt(), ReadInt());
atn.lexerActions[i_10] = lexerAction;
}
}
Expand Down Expand Up @@ -309,7 +297,7 @@ protected internal virtual void ReadEdges(ATN atn, IList<IntervalSet> sets)
}
}

protected internal virtual void ReadSets(ATN atn, IList<IntervalSet> sets, System.Func<int> readUnicode)
protected virtual void ReadSets(IList<IntervalSet> sets)
{
//
// SETS
Expand All @@ -327,7 +315,7 @@ protected internal virtual void ReadSets(ATN atn, IList<IntervalSet> sets, Syste
}
for (int j = 0; j < nintervals; j++)
{
set.Add(readUnicode(), readUnicode());
set.Add(ReadInt(), ReadInt());
}
}
}
Expand Down Expand Up @@ -368,11 +356,7 @@ protected internal virtual void ReadRules(ATN atn)
RuleStartState startState = (RuleStartState)atn.states[s];
atn.ruleToStartState[i_5] = startState;
if (atn.grammarType == ATNType.Lexer) {
int tokenType = ReadInt ();
if (tokenType == unchecked((int)(0xFFFF))) {
tokenType = TokenConstants.EOF;
}
atn.ruleToTokenType [i_5] = tokenType;
atn.ruleToTokenType [i_5] = ReadInt();
}
}
atn.ruleToStopState = new RuleStopState[nrules];
Expand Down Expand Up @@ -406,10 +390,6 @@ protected internal virtual void ReadStates(ATN atn)
continue;
}
int ruleIndex = ReadInt();
if (ruleIndex == char.MaxValue)
{
ruleIndex = -1;
}
ATNState s = StateFactory(stype, ruleIndex);
if (stype == StateType.LoopEnd)
{
Expand Down Expand Up @@ -459,7 +439,7 @@ protected internal virtual ATN ReadATN()

protected internal virtual void CheckVersion()
{
int version = ReadInt();
int version = ReadUInt16();
if (version != SerializedVersion)
{
string reason = string.Format(CultureInfo.CurrentCulture, "Could not deserialize ATN with version {0} (expected {1}).", version, SerializedVersion);
Expand Down Expand Up @@ -962,14 +942,30 @@ private static bool TestTailCall(ATN atn, RuleTransition transition, bool optimi
}


protected internal int ReadInt()
private int ReadInt()
{
int value = ReadUInt16();
if (value == 0xFFFF)
{
return -1;
}

int mask = value >> 14 & 0b11;
return mask == 0
? value
: mask == 0b01
? (ReadUInt16() << 14) | (value & ((1 << 14) - 1))
: ReadInt32();
}

private int ReadUInt16()
{
return data[p++];
return data[p++];
}

protected internal int ReadInt32()
private int ReadInt32()
{
return (int)data[p++] | ((int)data[p++] << 16);
return data[p++] | (data[p++] << 16);
}

[return: NotNull]
Expand Down
59 changes: 59 additions & 0 deletions runtime/Java/src/org/antlr/v4/runtime/atn/ATNDataReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package org.antlr.v4.runtime.atn;

public class ATNDataReader {
private final static int JavaOptimizeOffset2 = 0xFFFF - ATNDataWriter.JavaOptimizeOffset + 1;

private final char[] data;
private int p;

public ATNDataReader(char[] data) {
this.data = data;
}

public int read() {
int value = readUInt16();
if (value == 0xFFFF) {
return -1;
}

int mask = value >> ATNDataWriter.MaskBits & 0b11;
return mask == 0
? value
: mask == 0b01
? (readUInt16() << ATNDataWriter.MaskBits) | (value & ((1 << ATNDataWriter.MaskBits) - 1))
: readInt32();
}

public int readInt32() {
return readUInt16() | (readUInt16() << 16);
}

public int readUInt16() {
return readUInt16(true);
}

public int readUInt16(boolean normalize) {
int result = data[p++];
// Each char value in data is shifted by +2 at the entry to this method.
// This is an encoding optimization targeting the serialized values 0
// and -1 (serialized to 0xFFFF), each of which are very common in the
// serialized form of the ATN. In the modified UTF-8 that Java uses for
// compiled string literals, these two character values have multi-byte
// forms. By shifting each value by +2, they become characters 2 and 1
// prior to writing the string, each of which have single-byte
// representations. Since the shift occurs in the tool during ATN
// serialization, each target is responsible for adjusting the values
// during deserialization.
//
// As a special case, note that the first element of data is not
// adjusted because it contains the major version number of the
// serialized ATN, which was fixed at 3 at the time the value shifting
// was implemented.
if (normalize) {
return result >= ATNDataWriter.JavaOptimizeOffset
? result - ATNDataWriter.JavaOptimizeOffset
: result + JavaOptimizeOffset2;
}
return result;
}
}
66 changes: 66 additions & 0 deletions runtime/Java/src/org/antlr/v4/runtime/atn/ATNDataWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.antlr.v4.runtime.atn;

import org.antlr.v4.runtime.misc.IntegerList;

public class ATNDataWriter {
public static final int MaskBits = 14;
public static final int JavaOptimizeOffset = 2;

private final IntegerList data;
private final boolean isJava;

public ATNDataWriter(IntegerList data, String language) {
this.data = data;
this.isJava = language.equals("Java");
}

/* Write int of full range [Integer.MIN_VALUE..Integer.MAX_VALUE] in compact format
| encoding | count | type |
| ----------------------------------------------------------- | ----- | ------------ |
| 00xx xxxx xxxx xxxx | 1 | int (14 bit) |
| 01xx xxxx xxxx xxxx xxxx xxxx xxxx xxxx | 2 | int (30 bit) |
| 1000 0000 0000 0000 xxxx xxxx xxxx xxxx xxxx xxxx xxxx xxxx | 3 | int (32 bit) |
| 1111 1111 1111 1111 | 1 | -1 (0xFFFF) |
*/
public int write(int value) {
if (value == -1) {
writeUInt16(0xFFFF);
return 1;
}

if (value >= 0) {
if (value < 1 << MaskBits) {
writeUInt16(value);
return 1;
}
else if (value < 1 << (MaskBits + 16)) {
writeUInt16(value & ((1 << MaskBits) - 1) | 0b01 << MaskBits);
writeUInt16(value >>> MaskBits);
return 2;
}
}

writeUInt16(0b10 << MaskBits);
writeInt32(value);
return 3;
}

public void writeInt32(int value) {
writeUInt16((char)value);
writeUInt16((char)(value >> 16));
}

public void writeUInt16(int value) {
writeUInt16(value, true);
}

public void writeUInt16(int value, boolean optimize) {
if (value < Character.MIN_VALUE || value > Character.MAX_VALUE) {
throw new UnsupportedOperationException("Serialized ATN data element "+
data.size() + " element " + value + " out of range "+
(int)Character.MIN_VALUE + ".." + (int)Character.MAX_VALUE);
}

data.add(isJava && optimize ? (value + JavaOptimizeOffset) & 0xFFFF : value);
}
}
Loading