Skip to content

Commit

Permalink
Implement semanticallyLess operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
fruffy committed Jul 18, 2024
1 parent ad161b2 commit b27ddeb
Show file tree
Hide file tree
Showing 24 changed files with 574 additions and 53 deletions.
5 changes: 3 additions & 2 deletions backends/p4tools/common/lib/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@

#include <map>
#include <utility>
#include <vector>

#include <boost/container/flat_map.hpp>

#include "ir/ir.h"
#include "ir/node.h"
#include "ir/solver.h"
#include "ir/visitor.h"

namespace P4Tools {

/// Symbolic maps map a state variable to a IR::Expression.
using SymbolicMapType = boost::container::flat_map<IR::StateVariable, const IR::Expression *>;
using SymbolicMapType = boost::container::flat_map<IR::StateVariable, const IR::Expression *,
IR::IsSemanticallyLessComparator>;

/// Represents a solution found by the solver. A model is a concretized form of a symbolic
/// environment. All the expressions in a Model must be of type IR::Literal.
Expand Down
4 changes: 2 additions & 2 deletions backends/p4tools/modules/testgen/lib/concolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ ConcolicMethodImpls::ConcolicMethodImpls(const ImplList &implList) { add(implLis
bool ConcolicResolver::preorder(const IR::ConcolicVariable *var) {
auto concolicMethodName = var->concolicMethodName;
// Convert the concolic member variable to a state variable.
auto concolicReplacment = resolvedConcolicVariables.find(*var);
if (concolicReplacment == resolvedConcolicVariables.end()) {
auto concolicReplacement = resolvedConcolicVariables.find(*var);
if (concolicReplacement == resolvedConcolicVariables.end()) {
bool found = concolicMethodImpls.exec(concolicMethodName, var, state, evaluatedModel,
&resolvedConcolicVariables);
BUG_CHECK(found, "Unknown or unimplemented concolic method: %1%", concolicMethodName);
Expand Down
4 changes: 2 additions & 2 deletions backends/p4tools/modules/testgen/lib/concolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace P4Tools::P4Testgen {
/// unique keys. Using this map, you can look up particular state variables and check whether
/// they actually are present, but not expressions. The reason expressions need to be keys is
/// that sometimes entire expressions are mapped to a particular constant.
using ConcolicVariableMap =
ordered_map<std::variant<IR::ConcolicVariable, const IR::Expression *>, const IR::Expression *>;
using ConcolicVariableMap = ordered_map<std::reference_wrapper<const IR::Expression>,
const IR::Expression *, IR::IsSemanticallyLessComparator>;

/// Encapsulates a set of concolic method implementations.
class ConcolicMethodImpls {
Expand Down
10 changes: 1 addition & 9 deletions backends/p4tools/modules/testgen/lib/final_state.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
#include "backends/p4tools/modules/testgen/lib/final_state.h"

#include <list>
#include <utility>
#include <variant>
#include <vector>

#include <boost/container/vector.hpp>
Expand Down Expand Up @@ -78,13 +76,7 @@ std::optional<std::reference_wrapper<const FinalState>> FinalState::computeConco
const auto *concolicAssignment = resolvedConcolicVariable.second;
const IR::Expression *pathConstraint = nullptr;
// We need to differentiate between state variables and expressions here.
if (std::holds_alternative<IR::ConcolicVariable>(concolicVariable)) {
pathConstraint = new IR::Equ(std::get<IR::ConcolicVariable>(concolicVariable).clone(),
concolicAssignment);
} else if (std::holds_alternative<const IR::Expression *>(concolicVariable)) {
pathConstraint =
new IR::Equ(std::get<const IR::Expression *>(concolicVariable), concolicAssignment);
}
pathConstraint = new IR::Equ(&concolicVariable.get(), concolicAssignment);
CHECK_NULL(pathConstraint);
pathConstraint = state.get().getSymbolicEnv().subst(pathConstraint);
pathConstraint = P4::optimizeExpression(pathConstraint);
Expand Down
7 changes: 3 additions & 4 deletions backends/p4tools/modules/testgen/targets/bmv2/concolic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{
// Overwrite any previous assignment or result.
(*resolvedConcolicVariables)[*var] =
IR::Constant::get(checksumVarType, computedResult);

} else {
TESTGEN_UNIMPLEMENTED("Checksum output %1% of type %2% not supported", checksumVar,
checksumVar->type);
Expand All @@ -157,7 +156,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{
for (const auto &variable : resolvedExpressions) {
const auto *varName = variable.first;
const auto *varExpr = variable.second;
(*resolvedConcolicVariables)[varName] = varExpr;
(*resolvedConcolicVariables)[*varName] = varExpr;
}
}},
/* ======================================================================================
Expand Down Expand Up @@ -210,7 +209,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{
for (const auto &variable : resolvedExpressions) {
const auto *varName = variable.first;
const auto *varExpr = variable.second;
(*resolvedConcolicVariables)[varName] = varExpr;
(*resolvedConcolicVariables)[*varName] = varExpr;
}
}},

Expand Down Expand Up @@ -265,7 +264,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{
for (const auto &variable : resolvedExpressions) {
const auto *varName = variable.first;
const auto *varExpr = variable.second;
(*resolvedConcolicVariables)[varName] = varExpr;
(*resolvedConcolicVariables)[*varName] = varExpr;
}
}},
};
Expand Down
30 changes: 20 additions & 10 deletions backends/p4tools/p4tools.def
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ class StateVariable : Expression {
return *ref == *other.ref;
}

/// Implements comparisons so that StateVariables can be used as map keys.
bool operator<(const StateVariable &other) const {
// We use a custom compare function.
// TODO: Is there a faster way to implement this comparison?
return compare(ref, other.ref) < 0;
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
auto &a = static_cast<const StateVariable &>(a_);
return compare(ref, a.ref) < 0;
}

int compare(const Expression *e1, const Expression *e2) const {
Expand Down Expand Up @@ -135,17 +135,20 @@ class TaintExpression : Expression {
class SymbolicVariable : Expression {
#noconstructor

isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
auto &a = static_cast<const SymbolicVariable &>(a_);
return label < a.label; /* ignore type */
}


/// The label of the symbolic variable.
cstring label;

/// A symbolic variable always has a type and no source info.
SymbolicVariable(Type type, cstring label) : Expression(type), label(label) {}

/// Implements comparisons so that SymbolicVariables can be used as map keys.
bool operator<(const SymbolicVariable &other) const {
return label < other.label;
}

toString { return "|" + label +"(" + type->toString() + ")|"; }

dbprint { out << "|" + label +"(" << type << ")|"; }
Expand Down Expand Up @@ -225,6 +228,13 @@ public:
out << "Concolic_" << label << "(" << arguments << ")";
}

isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
auto &a = static_cast<const ConcolicVariable &>(a_);
return label < a.label; /* ignore type */
}

visit_children { v.visit(type, "type"); }

ConcolicVariable(const Type *type, cstring methodName,
Expand Down
4 changes: 4 additions & 0 deletions frontends/common/constantParsing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ std::ostream &operator<<(std::ostream &out, const UnparsedConstant &constant) {
return out;
}

bool operator<(const UnparsedConstant &a, const UnparsedConstant &b) {
return a.text < b.text || a.skip < b.skip || a.base < b.base || a.hasWidth < b.hasWidth;
}

/// A helper to parse constants which have an explicit width;
/// @see UnparsedConstant for an explanation of the parameters.
static IR::Constant *parseConstantWithWidth(Util::SourceInfo srcInfo, const char *text,
Expand Down
2 changes: 2 additions & 0 deletions frontends/common/constantParsing.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ struct UnparsedConstant {

std::ostream &operator<<(std::ostream &out, const UnparsedConstant &constant);

bool operator<(const UnparsedConstant &a, const UnparsedConstant &b);

/**
* Parses an UnparsedConstant @constant into an IR::Constant object, with
* location information taken from @srcInfo. If parsing fails, an IR::Constant
Expand Down
20 changes: 20 additions & 0 deletions ir/base.def
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ abstract Declaration : StatOrDecl, IDeclaration {
long declid = nextId++;
ID getName() const override { return name; }
equiv { return name == a.name; /* ignore declid */ }
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (StatOrDecl::isSemanticallyLess(a_)) return true;
auto &a = static_cast<const Declaration &>(a_);
return name < a.name; /* ignore declid */
}
private:
static long nextId;
public:
Expand All @@ -151,6 +157,12 @@ abstract Type_Declaration : Type, IDeclaration {
long declid = nextId++;
ID getName() const override { return name; }
equiv { return name == a.name; /* ignore declid */ }
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (Type::isSemanticallyLess(a_)) return true;
auto &a = static_cast<const Declaration &>(a_);
return name < a.name; /* ignore declid */
}
private:
static long nextId;
public:
Expand Down Expand Up @@ -218,6 +230,14 @@ class AnnotationToken {
cstring text;
optional NullOK UnparsedConstant* constInfo = nullptr;
dbprint { out << text; }
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
auto &a = static_cast<const AnnotationToken &>(a_);
return IR::isSemanticallyLess(token_type, a.token_type)
|| IR::isSemanticallyLess(text, a.text)
|| IR::isSemanticallyLess(*constInfo, *a.constInfo);
}
}

/// Annotations are used to provide additional information to the compiler
Expand Down
7 changes: 7 additions & 0 deletions ir/expression.def
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,13 @@ class Constant : Literal {
return Util::toString(value, width, sign, base);
}
visit_children { v.visit(type, "type"); }
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
if (!Literal::equiv(a_)) return Literal::isSemanticallyLess(a_);
auto &a = static_cast<const Constant &>(a_);
return value < a.value; /* ignore base */
}
}

class BoolLiteral : Literal {
Expand Down
3 changes: 3 additions & 0 deletions ir/id.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ struct ID : Util::IHasSourceInfo {
bool operator!=(cstring a) const { return name != a; }
bool operator==(const char *a) const { return name == a; }
bool operator!=(const char *a) const { return name != a; }
bool operator<(const ID &a) const { return name < a.name; }
bool operator<(const cstring &a) const { return name < a; }
bool operator<(const char *a) const { return name < a; }
explicit operator bool() const { return name; }
operator cstring() const { return name; }
std::string string() const { return name.string(); }
Expand Down
14 changes: 13 additions & 1 deletion ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ class Node : public virtual INode {
virtual bool operator==(const Node &a) const { return this->typeId() == a.typeId(); }
/* 'equiv' does a deep-equals comparison, comparing all non-pointer fields and recursing
* though all Node subclass pointers to compare them with 'equiv' as well. */
virtual bool equiv(const Node &a) const { return this->typeId() == a.typeId(); }
[[nodiscard]] virtual bool equiv(const Node &a) const { return this->typeId() == a.typeId(); }
[[nodiscard]] virtual bool isSemanticallyLess(const Node &a) const {
return this->typeId() < a.typeId();
}
#define DEFINE_OPEQ_FUNC(CLASS, BASE) \
virtual bool operator==(const CLASS &) const { return false; }
IRNODE_ALL_SUBCLASSES(DEFINE_OPEQ_FUNC)
Expand All @@ -174,9 +177,18 @@ inline bool equal(const INode *a, const INode *b) {
return a == b || (a && b && *a->getNode() == *b->getNode());
}
inline bool equiv(const Node *a, const Node *b) { return a == b || (a && b && a->equiv(*b)); }

inline bool equiv(const INode *a, const INode *b) {
return a == b || (a && b && a->getNode()->equiv(*b->getNode()));
}
struct IsSemanticallyLessComparator {
bool operator()(const IR::Node *s1, const IR::Node *s2) const {
return s1->isSemanticallyLess(*s2);
}
bool operator()(const IR::Node &s1, const IR::Node &s2) const {
return s1.isSemanticallyLess(s2);
}
};
// NOLINTBEGIN(bugprone-macro-parentheses)
/* common things that ALL Node subclasses must define */
#define IRNODE_SUBCLASS(T) \
Expand Down
29 changes: 29 additions & 0 deletions ir/semantic_less.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef IR_SEMANTIC_LESS_H_
#define IR_SEMANTIC_LESS_H_

#include <type_traits>

namespace IR {

template <class T, typename std::enable_if<std::is_integral<T>::value, T>::type * = nullptr>
bool isSemanticallyLess(const T &a, const T &b) {
return a < b;
}

template <class T, typename std::enable_if<std::is_enum<T>::value, T>::type * = nullptr>
inline bool isSemanticallyLess(const T &a, const T &b) {
return a < b;
}

/// TODO: This also includes containers such as safe::vectors.
// These containers may use operator< for comparison.
// Should it be the responsibility of the IR node implementer to ensure they are implementing
// container comparison correctly?
template <class T, typename std::enable_if<std::is_class<T>::value, T>::type * = nullptr>
inline bool isSemanticallyLess(const T &a, const T &b) {
return a < b;
}

} // namespace IR

#endif // IR_SEMANTIC_LESS_H_
2 changes: 1 addition & 1 deletion ir/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using Constraint = IR::Expression;
/// Comparator to compare SymbolicVariable pointers.
struct SymbolicVarComp {
bool operator()(const IR::SymbolicVariable *s1, const IR::SymbolicVariable *s2) const {
return s1->operator<(*s2);
return s1->isSemanticallyLess(*s2);
}
};

Expand Down
14 changes: 14 additions & 0 deletions ir/type.def
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ enum class Direction {
InOut
};

inline bool isSemanticallyLess(const IR::Direction &a, const IR::Direction &b) {
return a < b;
}

inline cstring directionToString(IR::Direction direction) {
switch (direction) {
case IR::Direction::None:
Expand Down Expand Up @@ -87,6 +91,11 @@ class Type_Any : Type, ITypeVar {
(void)a; // silence unused warning
return true; /* ignore declid */
}
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (Type::isSemanticallyLess(a_)) return true;
return typeId() < a_.typeId(); /* ignore declid */
}
}

/// This type is a fragment of another type.
Expand Down Expand Up @@ -267,6 +276,11 @@ class Type_InfInt : Type, ITypeVar {
(void)a; // silence unused warning
return true; /* ignore declid */
}
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (Type::isSemanticallyLess(a_)) return true;
return typeId() < a_.typeId(); /* ignore declid */
}
const Type* getP4Type() const override { return this; }
}

Expand Down
14 changes: 14 additions & 0 deletions ir/v1.def
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ class CalculatedField : IAnnotated {
ID name;
Expression cond;
update_or_verify() { } // FIXME -- needed by umpack_json(safe_vector) -- should not be
// update_or_verify is not an IR node, we have to implement our own comparator.
bool isSemanticallyLess(update_or_verify const & a) const {
return update < a.update
|| name < a.name
|| (cond != nullptr ? a.cond != nullptr ? cond->isSemanticallyLess(*a.cond) : false : a.cond != nullptr);
}
}
safe_vector<update_or_verify> specs = {};
Annotations annotations;
Expand All @@ -239,6 +245,14 @@ class CalculatedField : IAnnotated {
v.visit(field, "field");
for (auto &s : specs) v.visit(s.cond, s.name.name);
v.visit(annotations, "annotations"); }
isSemanticallyLess {
if (static_cast<const Node *>(this) == &a_) return false;
if (typeId() != a_.typeId()) return typeId() < a_.typeId();
auto &a = static_cast<const CalculatedField &>(a_);
return (field != nullptr ? a.field != nullptr ? field->isSemanticallyLess(*a.field) : false : a.field != nullptr)
|| std::lexicographical_compare(specs.begin(), specs.end(), a.specs.begin(), a.specs.end(), [](const update_or_verify &a, const update_or_verify &b) { return a.isSemanticallyLess(b); })
|| (annotations != nullptr ? a.annotations != nullptr ? annotations->isSemanticallyLess(*a.annotations) : false : a.annotations != nullptr);
}
}

class ParserValueSet : IAnnotated {
Expand Down
Loading

0 comments on commit b27ddeb

Please sign in to comment.