Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C++] Cleanup ATNDeserializer and remove related deprecated methods from ATNSimulator #3545

Merged
merged 1 commit into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 108 additions & 110 deletions runtime/Cpp/runtime/src/atn/ATNDeserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,59 +61,125 @@ using namespace antlrcpp;

namespace {

uint32_t deserializeInt32(const std::vector<uint16_t>& data, size_t offset) {
return static_cast<uint32_t>(data[offset]) | (static_cast<uint32_t>(data[offset + 1]) << 16);
}
void checkCondition(bool condition, std::string_view message) {
if (!condition) {
throw IllegalStateException(std::string(message));
}
}

ssize_t readUnicodeInt(const std::vector<uint16_t>& data, int& p) {
return static_cast<ssize_t>(data[p++]);
}
void checkCondition(bool condition) {
checkCondition(condition, "");
}

ssize_t readUnicodeInt32(const std::vector<uint16_t>& data, int& p) {
auto result = deserializeInt32(data, p);
p += 2;
return static_cast<ssize_t>(result);
}
/**
* Analyze the {@link StarLoopEntryState} states in the specified ATN to set
* the {@link StarLoopEntryState#isPrecedenceDecision} field to the
* correct value.
*
* @param atn The ATN.
*/
void markPrecedenceDecisions(const ATN &atn) {
for (ATNState *state : atn.states) {
if (!is<StarLoopEntryState*>(state)) {
continue;
}

// We templatize this on the function type so the optimizer can inline
// the 16- or 32-bit readUnicodeInt/readUnicodeInt32 as needed.
template <typename F>
void deserializeSets(
const std::vector<uint16_t>& data,
int& p,
std::vector<misc::IntervalSet>& sets,
F readUnicode) {
size_t nsets = data[p++];
sets.reserve(sets.size() + nsets);
for (size_t i = 0; i < nsets; i++) {
size_t nintervals = data[p++];
misc::IntervalSet set;

bool containsEof = data[p++] != 0;
if (containsEof) {
set.add(-1);
/* We analyze the ATN to determine if this ATN decision state is the
* decision for the closure block that determines whether a
* precedence rule should continue or complete.
*/
if (atn.ruleToStartState[state->ruleIndex]->isLeftRecursiveRule) {
ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (is<LoopEndState *>(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
downCast<StarLoopEntryState*>(state)->isPrecedenceDecision = true;
}
}
}
}
}

Ref<LexerAction> lexerActionFactory(LexerActionType type, int data1, int data2) {
switch (type) {
case LexerActionType::CHANNEL:
return std::make_shared<LexerChannelAction>(data1);

case LexerActionType::CUSTOM:
return std::make_shared<LexerCustomAction>(data1, data2);

case LexerActionType::MODE:
return std::make_shared< LexerModeAction>(data1);

for (size_t j = 0; j < nintervals; j++) {
auto a = readUnicode(data, p);
auto b = readUnicode(data, p);
set.add(a, b);
case LexerActionType::MORE:
return LexerMoreAction::getInstance();

case LexerActionType::POP_MODE:
return LexerPopModeAction::getInstance();

case LexerActionType::PUSH_MODE:
return std::make_shared<LexerPushModeAction>(data1);

case LexerActionType::SKIP:
return LexerSkipAction::getInstance();

case LexerActionType::TYPE:
return std::make_shared<LexerTypeAction>(data1);

default:
throw IllegalArgumentException("The specified lexer action type " + std::to_string(static_cast<size_t>(type)) +
" is not valid.");
}
sets.push_back(set);
}
}

}
uint32_t deserializeInt32(const std::vector<uint16_t>& data, size_t offset) {
return static_cast<uint32_t>(data[offset]) | (static_cast<uint32_t>(data[offset + 1]) << 16);
}

ATNDeserializer::ATNDeserializer(): ATNDeserializer(ATNDeserializationOptions::getDefaultOptions()) {
}
ssize_t readUnicodeInt(const std::vector<uint16_t>& data, int& p) {
return static_cast<ssize_t>(data[p++]);
}

ATNDeserializer::ATNDeserializer(const ATNDeserializationOptions& dso): _deserializationOptions(dso) {
}
ssize_t readUnicodeInt32(const std::vector<uint16_t>& data, int& p) {
auto result = deserializeInt32(data, p);
p += 2;
return static_cast<ssize_t>(result);
}

// We templatize this on the function type so the optimizer can inline
// the 16- or 32-bit readUnicodeInt/readUnicodeInt32 as needed.
template <typename F>
void deserializeSets(
const std::vector<uint16_t>& data,
int& p,
std::vector<misc::IntervalSet>& sets,
F readUnicode) {
size_t nsets = data[p++];
sets.reserve(sets.size() + nsets);
for (size_t i = 0; i < nsets; i++) {
size_t nintervals = data[p++];
misc::IntervalSet set;

bool containsEof = data[p++] != 0;
if (containsEof) {
set.add(-1);
}

for (size_t j = 0; j < nintervals; j++) {
auto a = readUnicode(data, p);
auto b = readUnicode(data, p);
set.add(a, b);
}
sets.push_back(set);
}
}

ATNDeserializer::~ATNDeserializer() {
}
ATN ATNDeserializer::deserialize(const std::vector<uint16_t>& data) {

ATNDeserializer::ATNDeserializer() : ATNDeserializer(ATNDeserializationOptions::getDefaultOptions()) {}

ATNDeserializer::ATNDeserializer(ATNDeserializationOptions deserializationOptions) : _deserializationOptions(std::move(deserializationOptions)) {}

ATN ATNDeserializer::deserialize(const std::vector<uint16_t>& data) const {
int p = 0;
int version = data[p++];
if (version != SERIALIZED_VERSION) {
Expand Down Expand Up @@ -448,35 +514,7 @@ ATN ATNDeserializer::deserialize(const std::vector<uint16_t>& data) {
return atn;
}

/**
* Analyze the {@link StarLoopEntryState} states in the specified ATN to set
* the {@link StarLoopEntryState#isPrecedenceDecision} field to the
* correct value.
*
* @param atn The ATN.
*/
void ATNDeserializer::markPrecedenceDecisions(const ATN &atn) const {
for (ATNState *state : atn.states) {
if (!is<StarLoopEntryState*>(state)) {
continue;
}

/* We analyze the ATN to determine if this ATN decision state is the
* decision for the closure block that determines whether a
* precedence rule should continue or complete.
*/
if (atn.ruleToStartState[state->ruleIndex]->isLeftRecursiveRule) {
ATNState *maybeLoopEndState = state->transitions[state->transitions.size() - 1]->target;
if (is<LoopEndState *>(maybeLoopEndState)) {
if (maybeLoopEndState->epsilonOnlyTransitions && is<RuleStopState*>(maybeLoopEndState->transitions[0]->target)) {
downCast<StarLoopEntryState*>(state)->isPrecedenceDecision = true;
}
}
}
}
}

void ATNDeserializer::verifyATN(const ATN &atn) {
void ATNDeserializer::verifyATN(const ATN &atn) const {
// verify assumptions
for (ATNState *state : atn.states) {
if (state == nullptr) {
Expand Down Expand Up @@ -535,16 +573,6 @@ void ATNDeserializer::verifyATN(const ATN &atn) {
}
}

void ATNDeserializer::checkCondition(bool condition) {
checkCondition(condition, "");
}

void ATNDeserializer::checkCondition(bool condition, const std::string &message) {
if (!condition) {
throw IllegalStateException(message);
}
}

ConstTransitionPtr ATNDeserializer::edgeFactory(const ATN &atn, size_t type, size_t /*src*/, size_t trg, size_t arg1,
size_t arg2, size_t arg3,
const std::vector<misc::IntervalSet> &sets) {
Expand Down Expand Up @@ -635,34 +663,4 @@ ATNState* ATNDeserializer::stateFactory(size_t type, size_t ruleIndex) {
return s;
}

Ref<LexerAction> ATNDeserializer::lexerActionFactory(LexerActionType type, int data1, int data2) const {
switch (type) {
case LexerActionType::CHANNEL:
return std::make_shared<LexerChannelAction>(data1);

case LexerActionType::CUSTOM:
return std::make_shared<LexerCustomAction>(data1, data2);

case LexerActionType::MODE:
return std::make_shared< LexerModeAction>(data1);

case LexerActionType::MORE:
return LexerMoreAction::getInstance();

case LexerActionType::POP_MODE:
return LexerPopModeAction::getInstance();

case LexerActionType::PUSH_MODE:
return std::make_shared<LexerPushModeAction>(data1);

case LexerActionType::SKIP:
return LexerSkipAction::getInstance();

case LexerActionType::TYPE:
return std::make_shared<LexerTypeAction>(data1);

default:
throw IllegalArgumentException("The specified lexer action type " + std::to_string(static_cast<size_t>(type)) +
" is not valid.");
}
}
37 changes: 14 additions & 23 deletions runtime/Cpp/runtime/src/atn/ATNDeserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,32 @@

#pragma once

#include "atn/LexerAction.h"
#include "atn/ATNDeserializationOptions.h"
#include "atn/LexerAction.h"
#include "atn/Transition.h"

namespace antlr4 {
namespace atn {

class ANTLR4CPP_PUBLIC ATNDeserializer {
public:
static constexpr size_t SERIALIZED_VERSION = 4;

ATNDeserializer();

explicit ATNDeserializer(const ATNDeserializationOptions& dso);

virtual ~ATNDeserializer();
class ANTLR4CPP_PUBLIC ATNDeserializer final {
public:
static constexpr size_t SERIALIZED_VERSION = 4;

virtual ATN deserialize(const std::vector<uint16_t> &input);
virtual void verifyATN(const ATN &atn);
ATNDeserializer();

static void checkCondition(bool condition);
static void checkCondition(bool condition, const std::string &message);
explicit ATNDeserializer(ATNDeserializationOptions deserializationOptions);

static ConstTransitionPtr edgeFactory(const ATN &atn, size_t type, size_t src, size_t trg, size_t arg1, size_t arg2,
size_t arg3, const std::vector<misc::IntervalSet> &sets);
ATN deserialize(const std::vector<uint16_t> &input) const;
void verifyATN(const ATN &atn) const;

static ATNState *stateFactory(size_t type, size_t ruleIndex);
static ConstTransitionPtr edgeFactory(const ATN &atn, size_t type, size_t src, size_t trg, size_t arg1, size_t arg2,
size_t arg3, const std::vector<misc::IntervalSet> &sets);

protected:
void markPrecedenceDecisions(const ATN &atn) const;
Ref<LexerAction> lexerActionFactory(LexerActionType type, int data1, int data2) const;
static ATNState* stateFactory(size_t type, size_t ruleIndex);

private:
const ATNDeserializationOptions _deserializationOptions;
};
private:
const ATNDeserializationOptions _deserializationOptions;
};

} // namespace atn
} // namespace antlr4
38 changes: 6 additions & 32 deletions runtime/Cpp/runtime/src/atn/ATNSimulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,24 @@
* can be found in the LICENSE.txt file in the project root.
*/

#include "atn/ATNType.h"
#include "atn/ATNSimulator.h"

#include "atn/ATNConfigSet.h"
#include "dfa/DFAState.h"
#include "atn/ATNDeserializer.h"
#include "atn/ATNType.h"
#include "atn/EmptyPredictionContext.h"

#include "atn/ATNSimulator.h"
#include "dfa/DFAState.h"

using namespace antlr4;
using namespace antlr4::dfa;
using namespace antlr4::atn;

const Ref<DFAState> ATNSimulator::ERROR = std::make_shared<DFAState>(INT32_MAX);
const Ref<DFAState> ATNSimulator::ERROR = std::make_shared<DFAState>(std::numeric_limits<int>::max());
std::shared_mutex ATNSimulator::_stateLock;
std::shared_mutex ATNSimulator::_edgeLock;

ATNSimulator::ATNSimulator(const ATN &atn, PredictionContextCache &sharedContextCache)
: atn(atn), _sharedContextCache(sharedContextCache) {
}

ATNSimulator::~ATNSimulator() {
}
: atn(atn), _sharedContextCache(sharedContextCache) {}

void ATNSimulator::clearDFA() {
throw UnsupportedOperationException("This ATN simulator does not support clearing the DFA.");
Expand All @@ -39,25 +35,3 @@ Ref<const PredictionContext> ATNSimulator::getCachedContext(Ref<const Prediction
std::map<Ref<const PredictionContext>, Ref<const PredictionContext>> visited;
return PredictionContext::getCachedContext(context, _sharedContextCache, visited);
}

ATN ATNSimulator::deserialize(const std::vector<uint16_t> &data) {
ATNDeserializer deserializer;
return deserializer.deserialize(data);
}

void ATNSimulator::checkCondition(bool condition) {
ATNDeserializer::checkCondition(condition);
}

void ATNSimulator::checkCondition(bool condition, const std::string &message) {
ATNDeserializer::checkCondition(condition, message);
}

ConstTransitionPtr ATNSimulator::edgeFactory(const ATN &atn, int type, int src, int trg, int arg1, int arg2, int arg3,
const std::vector<misc::IntervalSet> &sets) {
return ATNDeserializer::edgeFactory(atn, type, src, trg, arg1, arg2, arg3, sets);
}

ATNState *ATNSimulator::stateFactory(int type, int ruleIndex) {
return ATNDeserializer::stateFactory(type, ruleIndex);
}
Loading