diff --git a/backends/p4tools/common/lib/model.h b/backends/p4tools/common/lib/model.h index ffc63aef3ea..62d3b3db4bc 100644 --- a/backends/p4tools/common/lib/model.h +++ b/backends/p4tools/common/lib/model.h @@ -3,18 +3,19 @@ #include #include -#include #include #include "ir/ir.h" +#include "ir/node.h" #include "ir/solver.h" #include "ir/visitor.h" namespace P4Tools { /// Symbolic maps map a state variable to a IR::Expression. -using SymbolicMapType = boost::container::flat_map; +using SymbolicMapType = boost::container::flat_map; /// Represents a solution found by the solver. A model is a concretized form of a symbolic /// environment. All the expressions in a Model must be of type IR::Literal. diff --git a/backends/p4tools/modules/testgen/lib/concolic.cpp b/backends/p4tools/modules/testgen/lib/concolic.cpp index 5cbbd949340..4366698f41a 100644 --- a/backends/p4tools/modules/testgen/lib/concolic.cpp +++ b/backends/p4tools/modules/testgen/lib/concolic.cpp @@ -96,8 +96,8 @@ ConcolicMethodImpls::ConcolicMethodImpls(const ImplList &implList) { add(implLis bool ConcolicResolver::preorder(const IR::ConcolicVariable *var) { auto concolicMethodName = var->concolicMethodName; // Convert the concolic member variable to a state variable. - auto concolicReplacment = resolvedConcolicVariables.find(*var); - if (concolicReplacment == resolvedConcolicVariables.end()) { + auto concolicReplacement = resolvedConcolicVariables.find(*var); + if (concolicReplacement == resolvedConcolicVariables.end()) { bool found = concolicMethodImpls.exec(concolicMethodName, var, state, evaluatedModel, &resolvedConcolicVariables); BUG_CHECK(found, "Unknown or unimplemented concolic method: %1%", concolicMethodName); diff --git a/backends/p4tools/modules/testgen/lib/concolic.h b/backends/p4tools/modules/testgen/lib/concolic.h index f7040cab2b0..ba929c54f62 100644 --- a/backends/p4tools/modules/testgen/lib/concolic.h +++ b/backends/p4tools/modules/testgen/lib/concolic.h @@ -26,8 +26,8 @@ namespace P4Tools::P4Testgen { /// unique keys. Using this map, you can look up particular state variables and check whether /// they actually are present, but not expressions. The reason expressions need to be keys is /// that sometimes entire expressions are mapped to a particular constant. -using ConcolicVariableMap = - ordered_map, const IR::Expression *>; +using ConcolicVariableMap = ordered_map, + const IR::Expression *, IR::IsSemanticallyLessComparator>; /// Encapsulates a set of concolic method implementations. class ConcolicMethodImpls { diff --git a/backends/p4tools/modules/testgen/lib/final_state.cpp b/backends/p4tools/modules/testgen/lib/final_state.cpp index a4eb7436ee1..b903da01423 100644 --- a/backends/p4tools/modules/testgen/lib/final_state.cpp +++ b/backends/p4tools/modules/testgen/lib/final_state.cpp @@ -1,8 +1,6 @@ #include "backends/p4tools/modules/testgen/lib/final_state.h" -#include #include -#include #include #include @@ -78,13 +76,7 @@ std::optional> FinalState::computeConco const auto *concolicAssignment = resolvedConcolicVariable.second; const IR::Expression *pathConstraint = nullptr; // We need to differentiate between state variables and expressions here. - if (std::holds_alternative(concolicVariable)) { - pathConstraint = new IR::Equ(std::get(concolicVariable).clone(), - concolicAssignment); - } else if (std::holds_alternative(concolicVariable)) { - pathConstraint = - new IR::Equ(std::get(concolicVariable), concolicAssignment); - } + pathConstraint = new IR::Equ(&concolicVariable.get(), concolicAssignment); CHECK_NULL(pathConstraint); pathConstraint = state.get().getSymbolicEnv().subst(pathConstraint); pathConstraint = P4::optimizeExpression(pathConstraint); diff --git a/backends/p4tools/modules/testgen/targets/bmv2/concolic.cpp b/backends/p4tools/modules/testgen/targets/bmv2/concolic.cpp index 2c37bf3e09c..84851164c3b 100644 --- a/backends/p4tools/modules/testgen/targets/bmv2/concolic.cpp +++ b/backends/p4tools/modules/testgen/targets/bmv2/concolic.cpp @@ -145,7 +145,6 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{ // Overwrite any previous assignment or result. (*resolvedConcolicVariables)[*var] = IR::Constant::get(checksumVarType, computedResult); - } else { TESTGEN_UNIMPLEMENTED("Checksum output %1% of type %2% not supported", checksumVar, checksumVar->type); @@ -157,7 +156,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{ for (const auto &variable : resolvedExpressions) { const auto *varName = variable.first; const auto *varExpr = variable.second; - (*resolvedConcolicVariables)[varName] = varExpr; + (*resolvedConcolicVariables)[*varName] = varExpr; } }}, /* ====================================================================================== @@ -210,7 +209,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{ for (const auto &variable : resolvedExpressions) { const auto *varName = variable.first; const auto *varExpr = variable.second; - (*resolvedConcolicVariables)[varName] = varExpr; + (*resolvedConcolicVariables)[*varName] = varExpr; } }}, @@ -265,7 +264,7 @@ const ConcolicMethodImpls::ImplList Bmv2Concolic::BMV2_CONCOLIC_METHOD_IMPLS{ for (const auto &variable : resolvedExpressions) { const auto *varName = variable.first; const auto *varExpr = variable.second; - (*resolvedConcolicVariables)[varName] = varExpr; + (*resolvedConcolicVariables)[*varName] = varExpr; } }}, }; diff --git a/backends/p4tools/p4tools.def b/backends/p4tools/p4tools.def index 44d55d1db61..7e6d1e0a769 100644 --- a/backends/p4tools/p4tools.def +++ b/backends/p4tools/p4tools.def @@ -28,11 +28,11 @@ class StateVariable : Expression { return *ref == *other.ref; } - /// Implements comparisons so that StateVariables can be used as map keys. - bool operator<(const StateVariable &other) const { - // We use a custom compare function. - // TODO: Is there a faster way to implement this comparison? - return compare(ref, other.ref) < 0; + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + auto &a = static_cast(a_); + return compare(ref, a.ref) < 0; } int compare(const Expression *e1, const Expression *e2) const { @@ -135,17 +135,20 @@ class TaintExpression : Expression { class SymbolicVariable : Expression { #noconstructor + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + auto &a = static_cast(a_); + return label < a.label; /* ignore type */ + } + + /// The label of the symbolic variable. cstring label; /// A symbolic variable always has a type and no source info. SymbolicVariable(Type type, cstring label) : Expression(type), label(label) {} - /// Implements comparisons so that SymbolicVariables can be used as map keys. - bool operator<(const SymbolicVariable &other) const { - return label < other.label; - } - toString { return "|" + label +"(" + type->toString() + ")|"; } dbprint { out << "|" + label +"(" << type << ")|"; } @@ -225,6 +228,13 @@ public: out << "Concolic_" << label << "(" << arguments << ")"; } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + auto &a = static_cast(a_); + return label < a.label; /* ignore type */ + } + visit_children { v.visit(type, "type"); } ConcolicVariable(const Type *type, cstring methodName, diff --git a/frontends/common/constantParsing.cpp b/frontends/common/constantParsing.cpp index 5851a0eed6c..6823b3042ce 100644 --- a/frontends/common/constantParsing.cpp +++ b/frontends/common/constantParsing.cpp @@ -27,6 +27,10 @@ std::ostream &operator<<(std::ostream &out, const UnparsedConstant &constant) { return out; } +bool operator<(const UnparsedConstant &a, const UnparsedConstant &b) { + return a.text < b.text || a.skip < b.skip || a.base < b.base || a.hasWidth < b.hasWidth; +} + /// A helper to parse constants which have an explicit width; /// @see UnparsedConstant for an explanation of the parameters. static IR::Constant *parseConstantWithWidth(Util::SourceInfo srcInfo, const char *text, diff --git a/frontends/common/constantParsing.h b/frontends/common/constantParsing.h index 00b34111f06..bb9621c5672 100644 --- a/frontends/common/constantParsing.h +++ b/frontends/common/constantParsing.h @@ -69,6 +69,8 @@ struct UnparsedConstant { std::ostream &operator<<(std::ostream &out, const UnparsedConstant &constant); +bool operator<(const UnparsedConstant &a, const UnparsedConstant &b); + /** * Parses an UnparsedConstant @constant into an IR::Constant object, with * location information taken from @srcInfo. If parsing fails, an IR::Constant diff --git a/ir/base.def b/ir/base.def index a1e6d3f29e1..5435397cfd1 100644 --- a/ir/base.def +++ b/ir/base.def @@ -135,6 +135,12 @@ abstract Declaration : StatOrDecl, IDeclaration { long declid = nextId++; ID getName() const override { return name; } equiv { return name == a.name; /* ignore declid */ } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (StatOrDecl::isSemanticallyLess(a_)) return true; + auto &a = static_cast(a_); + return name < a.name; /* ignore declid */ + } private: static long nextId; public: @@ -151,6 +157,12 @@ abstract Type_Declaration : Type, IDeclaration { long declid = nextId++; ID getName() const override { return name; } equiv { return name == a.name; /* ignore declid */ } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (Type::isSemanticallyLess(a_)) return true; + auto &a = static_cast(a_); + return name < a.name; /* ignore declid */ + } private: static long nextId; public: @@ -218,6 +230,14 @@ class AnnotationToken { cstring text; optional NullOK UnparsedConstant* constInfo = nullptr; dbprint { out << text; } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + auto &a = static_cast(a_); + return IR::isSemanticallyLess(token_type, a.token_type) + || IR::isSemanticallyLess(text, a.text) + || IR::isSemanticallyLess(*constInfo, *a.constInfo); + } } /// Annotations are used to provide additional information to the compiler diff --git a/ir/expression.def b/ir/expression.def index c6b61cd17a2..4390f6c894b 100644 --- a/ir/expression.def +++ b/ir/expression.def @@ -279,6 +279,13 @@ class Constant : Literal { return Util::toString(value, width, sign, base); } visit_children { v.visit(type, "type"); } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + if (!Literal::equiv(a_)) return Literal::isSemanticallyLess(a_); + auto &a = static_cast(a_); + return value < a.value; /* ignore base */ + } } class BoolLiteral : Literal { diff --git a/ir/id.h b/ir/id.h index 22eced72488..89e7e7e157e 100644 --- a/ir/id.h +++ b/ir/id.h @@ -52,6 +52,9 @@ struct ID : Util::IHasSourceInfo { bool operator!=(cstring a) const { return name != a; } bool operator==(const char *a) const { return name == a; } bool operator!=(const char *a) const { return name != a; } + bool operator<(const ID &a) const { return name < a.name; } + bool operator<(const cstring &a) const { return name < a; } + bool operator<(const char *a) const { return name < a; } explicit operator bool() const { return name; } operator cstring() const { return name; } std::string string() const { return name.string(); } diff --git a/ir/node.h b/ir/node.h index 93b1ebb1da4..41c910e6085 100644 --- a/ir/node.h +++ b/ir/node.h @@ -153,7 +153,10 @@ class Node : public virtual INode { virtual bool operator==(const Node &a) const { return this->typeId() == a.typeId(); } /* 'equiv' does a deep-equals comparison, comparing all non-pointer fields and recursing * though all Node subclass pointers to compare them with 'equiv' as well. */ - virtual bool equiv(const Node &a) const { return this->typeId() == a.typeId(); } + [[nodiscard]] virtual bool equiv(const Node &a) const { return this->typeId() == a.typeId(); } + [[nodiscard]] virtual bool isSemanticallyLess(const Node &a) const { + return this->typeId() < a.typeId(); + } #define DEFINE_OPEQ_FUNC(CLASS, BASE) \ virtual bool operator==(const CLASS &) const { return false; } IRNODE_ALL_SUBCLASSES(DEFINE_OPEQ_FUNC) @@ -174,9 +177,18 @@ inline bool equal(const INode *a, const INode *b) { return a == b || (a && b && *a->getNode() == *b->getNode()); } inline bool equiv(const Node *a, const Node *b) { return a == b || (a && b && a->equiv(*b)); } + inline bool equiv(const INode *a, const INode *b) { return a == b || (a && b && a->getNode()->equiv(*b->getNode())); } +struct IsSemanticallyLessComparator { + bool operator()(const IR::Node *s1, const IR::Node *s2) const { + return s1->isSemanticallyLess(*s2); + } + bool operator()(const IR::Node &s1, const IR::Node &s2) const { + return s1.isSemanticallyLess(s2); + } +}; // NOLINTBEGIN(bugprone-macro-parentheses) /* common things that ALL Node subclasses must define */ #define IRNODE_SUBCLASS(T) \ diff --git a/ir/semantic_less.h b/ir/semantic_less.h new file mode 100644 index 00000000000..aa6eeea5d83 --- /dev/null +++ b/ir/semantic_less.h @@ -0,0 +1,29 @@ +#ifndef IR_SEMANTIC_LESS_H_ +#define IR_SEMANTIC_LESS_H_ + +#include + +namespace IR { + +template ::value, T>::type * = nullptr> +bool isSemanticallyLess(const T &a, const T &b) { + return a < b; +} + +template ::value, T>::type * = nullptr> +inline bool isSemanticallyLess(const T &a, const T &b) { + return a < b; +} + +/// TODO: This also includes containers such as safe::vectors. +// These containers may use operator< for comparison. +// Should it be the responsibility of the IR node implementer to ensure they are implementing +// container comparison correctly? +template ::value, T>::type * = nullptr> +inline bool isSemanticallyLess(const T &a, const T &b) { + return a < b; +} + +} // namespace IR + +#endif // IR_SEMANTIC_LESS_H_ diff --git a/ir/solver.h b/ir/solver.h index 459be62fa28..6343879e1d7 100644 --- a/ir/solver.h +++ b/ir/solver.h @@ -18,7 +18,7 @@ using Constraint = IR::Expression; /// Comparator to compare SymbolicVariable pointers. struct SymbolicVarComp { bool operator()(const IR::SymbolicVariable *s1, const IR::SymbolicVariable *s2) const { - return s1->operator<(*s2); + return s1->isSemanticallyLess(*s2); } }; diff --git a/ir/type.def b/ir/type.def index 1bc9920f13c..942cf88323e 100644 --- a/ir/type.def +++ b/ir/type.def @@ -17,6 +17,10 @@ enum class Direction { InOut }; +inline bool isSemanticallyLess(const IR::Direction &a, const IR::Direction &b) { + return a < b; +} + inline cstring directionToString(IR::Direction direction) { switch (direction) { case IR::Direction::None: @@ -87,6 +91,11 @@ class Type_Any : Type, ITypeVar { (void)a; // silence unused warning return true; /* ignore declid */ } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (Type::isSemanticallyLess(a_)) return true; + return typeId() < a_.typeId(); /* ignore declid */ + } } /// This type is a fragment of another type. @@ -267,6 +276,11 @@ class Type_InfInt : Type, ITypeVar { (void)a; // silence unused warning return true; /* ignore declid */ } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (Type::isSemanticallyLess(a_)) return true; + return typeId() < a_.typeId(); /* ignore declid */ + } const Type* getP4Type() const override { return this; } } diff --git a/ir/v1.def b/ir/v1.def index 0cf0583ee16..d4162081b9f 100644 --- a/ir/v1.def +++ b/ir/v1.def @@ -231,6 +231,12 @@ class CalculatedField : IAnnotated { ID name; Expression cond; update_or_verify() { } // FIXME -- needed by umpack_json(safe_vector) -- should not be + // update_or_verify is not an IR node, we have to implement our own comparator. + bool isSemanticallyLess(update_or_verify const & a) const { + return update < a.update + || name < a.name + || (cond != nullptr ? a.cond != nullptr ? cond->isSemanticallyLess(*a.cond) : false : a.cond != nullptr); + } } safe_vector specs = {}; Annotations annotations; @@ -239,6 +245,14 @@ class CalculatedField : IAnnotated { v.visit(field, "field"); for (auto &s : specs) v.visit(s.cond, s.name.name); v.visit(annotations, "annotations"); } + isSemanticallyLess { + if (static_cast(this) == &a_) return false; + if (typeId() != a_.typeId()) return typeId() < a_.typeId(); + auto &a = static_cast(a_); + return (field != nullptr ? a.field != nullptr ? field->isSemanticallyLess(*a.field) : false : a.field != nullptr) + || std::lexicographical_compare(specs.begin(), specs.end(), a.specs.begin(), a.specs.end(), [](const update_or_verify &a, const update_or_verify &b) { return a.isSemanticallyLess(b); }) + || (annotations != nullptr ? a.annotations != nullptr ? annotations->isSemanticallyLess(*a.annotations) : false : a.annotations != nullptr); + } } class ParserValueSet : IAnnotated { diff --git a/ir/vector.h b/ir/vector.h index 9e18b048ad4..69cb8aeab30 100644 --- a/ir/vector.h +++ b/ir/vector.h @@ -186,12 +186,29 @@ class Vector : public VectorBase { if (static_cast(this) == &a_) return true; if (this->typeId() != a_.typeId()) return false; auto &a = static_cast &>(a_); - if (size() != a.size()) return false; + if (size() != a.size()) { + return false; + } auto it = a.begin(); - for (auto *el : *this) - if (!el->equiv(**it++)) return false; + for (auto *el : *this) { + if (!el->equiv(**it++)) { + return false; + } + } return true; } + bool isSemanticallyLess(const Node &a_) const override { + if (static_cast(this) == &a_) { + return false; + } + if (VectorBase::isSemanticallyLess(a_)) { + return true; + } + auto &a = static_cast &>(a_); + return std::lexicographical_compare( + vec.begin(), vec.end(), a.vec.begin(), a.vec.end(), + [](const T *a, const T *b) { return a->isSemanticallyLess(*b); }); + } cstring node_type_name() const override { return "Vector<" + T::static_type_name() + ">"; } static cstring static_type_name() { return "Vector<" + T::static_type_name() + ">"; } void visit_children(Visitor &v) override; diff --git a/lib/ordered_map.h b/lib/ordered_map.h index 0936151f8f5..f03202cd6ec 100644 --- a/lib/ordered_map.h +++ b/lib/ordered_map.h @@ -124,6 +124,20 @@ class ordered_map { size_type max_size() const noexcept { return data_map.max_size(); } bool operator==(const ordered_map &a) const { return data == a.data; } bool operator!=(const ordered_map &a) const { return data != a.data; } + bool operator<(const ordered_map &a) const { + // we define this to work INDEPENDENT of the order -- so it is possible to have + // two ordered_maps where !(a < b) && !(b < a) && !(a == b) -- such sets have the + // same elements but in a different order. This is generally what you want if you + // have a set of ordered_maps (or use ordered_map as a map key). + auto it = a.data_map.begin(); + for (auto &el : data_map) { + if (it == a.data_map.end()) return false; + if (mapcmp()(el.first, it->first)) return true; + if (mapcmp()(it->first, el.first)) return false; + ++it; + } + return it != a.data_map.end(); + } void clear() { data.clear(); data_map.clear(); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 19908274d06..ec621976918 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -43,6 +43,7 @@ set (GTEST_UNITTEST_SOURCES gtest/midend_test.cpp gtest/frontend_test.cpp gtest/opeq_test.cpp + gtest/semantically_less_test.cpp gtest/ordered_map.cpp gtest/ordered_set.cpp gtest/parser_unroll.cpp diff --git a/test/gtest/semantically_less_test.cpp b/test/gtest/semantically_less_test.cpp new file mode 100644 index 00000000000..45d17f96798 --- /dev/null +++ b/test/gtest/semantically_less_test.cpp @@ -0,0 +1,313 @@ +/* +Copyright 2024-present New York University. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include + +#include "ir/ir.h" +#include "ir/node.h" +#include "ir/visitor.h" + +static inline bool isSemanticallyLess(const IR::Node &a, const IR::Node &b) { + return a.isSemanticallyLess(b); +} + +#define EXECUTE_FUNCTION_FOR_P4C_NODE(BASE_TYPE, a, b, function) \ + EXPECT_TRUE(function(*(a), *(b))); \ + EXPECT_TRUE(function(*(a), *static_cast(b))); \ + EXPECT_TRUE(function(*(a), *static_cast(b))); \ + EXPECT_TRUE(function(*static_cast(a), *(b))); \ + EXPECT_TRUE(function(*static_cast(a), *static_cast(b))); \ + EXPECT_TRUE(function(*static_cast(a), *static_cast(b))); \ + EXPECT_TRUE(function(*static_cast(a), *(b))); \ + EXPECT_TRUE(function(*static_cast(a), *static_cast(b))); \ + EXPECT_TRUE(function(*static_cast(a), *static_cast(b))); + +#define CHECK_LESS_FOR_P4C_NODE(BASE_TYPE, a, b) \ + EXECUTE_FUNCTION_FOR_P4C_NODE(BASE_TYPE, a, b, isSemanticallyLess) +#define CHECK_LESS_FOR_P4C_TYPE(a, b) CHECK_LESS_FOR_P4C_NODE(IR::Type, a, b) +#define CHECK_LESS_FOR_P4C_EXPRESSION(a, b) CHECK_LESS_FOR_P4C_NODE(IR::Expression, a, b) +#define CHECK_LESS_FOR_P4C_NODE_VECTOR(a, b) CHECK_LESS_FOR_P4C_NODE(IR::Vector, a, b) +#define CHECK_LESS_FOR_P4C_DECLARATION(a, b) CHECK_LESS_FOR_P4C_NODE(IR::Declaration, a, b) + +#define CHECK_GREATER_FOR_P4C_TYPE(a, b) CHECK_LESS_FOR_P4C_TYPE(b, a) +#define CHECK_GREATER_FOR_P4C_EXPRESSION(a, b) CHECK_LESS_FOR_P4C_EXPRESSION(b, a) +#define CHECK_GREATER_FOR_P4C_NODE_VECTOR(a, b) CHECK_LESS_FOR_P4C_NODE_VECTOR(b, a) +#define CHECK_GREATER_FOR_P4C_DECLARATION(a, b) CHECK_LESS_FOR_P4C_DECLARATION(b, a) + +bool checkEqualityWithLess(const IR::Node &a, const IR::Node &b) { + return !(isSemanticallyLess(a, b) || isSemanticallyLess(b, a)); +} + +#define CHECK_EQUALITY_WITH_LESS(BASE_TYPE, a, b) \ + EXECUTE_FUNCTION_FOR_P4C_NODE(BASE_TYPE, a, b, checkEqualityWithLess) +#define CHECK_EQUALITY_FOR_P4C_TYPE(a, b) CHECK_EQUALITY_WITH_LESS(IR::Type, a, b) +#define CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, b) CHECK_EQUALITY_WITH_LESS(IR::Expression, a, b) +#define CHECK_EQUALITY_FOR_P4C_NODE_VECTOR(a, b) \ + CHECK_EQUALITY_WITH_LESS(IR::Vector, a, b) +#define CHECK_EQUALITY_FOR_P4C_DECLARATION(a, b) CHECK_EQUALITY_WITH_LESS(IR::Declaration, a, b) + +TEST(OperatorLess, Types) { + const auto *a = IR::Type_Bits::get(16, false); + const auto *b = IR::Type_Bits::get(32, false); + const auto *c = IR::Type_Bits::get(16, false); + + CHECK_LESS_FOR_P4C_TYPE(a, b); + CHECK_GREATER_FOR_P4C_TYPE(b, a); + CHECK_EQUALITY_FOR_P4C_TYPE(a, c); + + const auto *d = IR::Type_Boolean::get(); + CHECK_LESS_FOR_P4C_TYPE(d, a); + CHECK_GREATER_FOR_P4C_TYPE(a, d); + + const auto *e = IR::Type_InfInt::get(); + CHECK_LESS_FOR_P4C_TYPE(a, e); + CHECK_GREATER_FOR_P4C_TYPE(e, a); + + const auto *f = new IR::Type_Name("f"); + const auto *f2 = new IR::Type_Name("f"); + CHECK_LESS_FOR_P4C_TYPE(a, f); + CHECK_GREATER_FOR_P4C_TYPE(f, a); + CHECK_EQUALITY_FOR_P4C_TYPE(f, f2) + + const auto *g = IR::Type_String::get(); + CHECK_LESS_FOR_P4C_TYPE(a, g); + CHECK_GREATER_FOR_P4C_TYPE(g, a); +} + +TEST(OperatorLess, Constants) { + // Check unsigned constants. + { + const auto *a = new IR::Constant(IR::Type_Bits::get(16, false), 5); + const auto *b = new IR::Constant(IR::Type_Bits::get(16, false), 10); + const auto *c = new IR::Constant(IR::Type_Bits::get(16, false), 5); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, c); + } + // Check signed constants. + { + const auto *a = new IR::Constant(IR::Type_Bits::get(16, true), -1); + const auto *b = new IR::Constant(IR::Type_Bits::get(16, true), 0); + const auto *c = new IR::Constant(IR::Type_Bits::get(16, true), -1); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, c); + } + // Check Strings + { + const auto *a = new IR::StringLiteral(IR::Type_String::get(), cstring("a")); + const auto *b = new IR::StringLiteral(IR::Type_String::get(), cstring("b")); + const auto *a2 = new IR::StringLiteral(IR::Type_String::get(), cstring("a")); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, a2); + } + // Check booleans + { + const auto *a = new IR::BoolLiteral(false); + const auto *b = new IR::BoolLiteral(true); + const auto *c = new IR::BoolLiteral(false); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, c); + } + // Check infinite precision integers. + { + const auto *a = new IR::Constant(IR::Type_InfInt::get(), 5); + const auto *b = new IR::Constant(IR::Type_InfInt::get(), 10); + const auto *c = new IR::Constant(IR::Type_InfInt::get(), 5); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, c); + } +} + +TEST(OperatorLess, MixedConstants) { + const auto *t = IR::Type_Bits::get(16, false); + const auto *a = new IR::StringLiteral(IR::Type_String::get(), cstring("a")); + const auto *b = new IR::Constant(t, 0); + const auto *c = new IR::Constant(t, 10); + const auto *d = new IR::BoolLiteral(false); + + CHECK_LESS_FOR_P4C_EXPRESSION(b, a); + CHECK_LESS_FOR_P4C_EXPRESSION(c, a); + CHECK_GREATER_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(a, c); + + CHECK_LESS_FOR_P4C_EXPRESSION(d, a); + CHECK_GREATER_FOR_P4C_EXPRESSION(a, d); + CHECK_LESS_FOR_P4C_EXPRESSION(b, d); + CHECK_GREATER_FOR_P4C_EXPRESSION(d, b); +} + +TEST(OperatorLess, ConstantVectors) { + // Two empty vectors are equal. + const auto *e1 = new IR::IndexedVector(); + const auto *e2 = new IR::IndexedVector(); + CHECK_EQUALITY_FOR_P4C_NODE_VECTOR(e1, e2); + + const auto *t = IR::Type_Bits::get(16, false); + const auto *a = new IR::Constant(t, 5); + const auto *b = new IR::Constant(t, 10); + const auto *c = new IR::Constant(t, 5); + + auto *p1 = new IR::IndexedVector(a); + auto *p2 = new IR::IndexedVector(b); + auto *p3 = new IR::IndexedVector(c); + + CHECK_LESS_FOR_P4C_NODE_VECTOR(p1, p2); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p2, p1); + CHECK_EQUALITY_FOR_P4C_NODE_VECTOR(p1, p3); + + // Check that we correctly compare uneven vectors. + p2->push_back(b); + CHECK_LESS_FOR_P4C_NODE_VECTOR(p1, p2); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p2, p1); + p1->push_back(b); + p1->push_back(b); + CHECK_LESS_FOR_P4C_NODE_VECTOR(p1, p2); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p2, p1); + + // Check that we correctly compare vectors with different nodes. + const auto *p4 = new IR::IndexedVector(new IR::BoolLiteral(false)); + const auto *p5 = new IR::IndexedVector(new IR::Constant(t, 0)); + const auto *p6 = new IR::IndexedVector( + new IR::StringLiteral(IR::Type_String::get(), cstring("a"))); + + CHECK_LESS_FOR_P4C_NODE_VECTOR(p5, p4); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p4, p5); + CHECK_LESS_FOR_P4C_NODE_VECTOR(p4, p6); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p6, p4); + CHECK_LESS_FOR_P4C_NODE_VECTOR(p5, p6); + CHECK_GREATER_FOR_P4C_NODE_VECTOR(p6, p5); +} + +TEST(OperatorLess, UnaryExpressions) { + const auto *t = IR::Type_Bits::get(16, false); + const auto *a = new IR::Constant(t, 5); + const auto *b = new IR::Constant(t, 10); + const auto *c = new IR::Constant(t, 5); + + const auto *p1 = new IR::LNot(a); + const auto *p2 = new IR::LNot(b); + const auto *p3 = new IR::LNot(c); + + CHECK_LESS_FOR_P4C_EXPRESSION(p1, p2); + CHECK_GREATER_FOR_P4C_EXPRESSION(p2, p1); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(p1, p3); +} + +TEST(OperatorLess, BinaryExpressions) { + const auto *t = IR::Type_Bits::get(16, false); + const auto *a = new IR::Constant(t, 5); + const auto *b = new IR::Constant(t, 10); + const auto *c = new IR::Constant(t, 5); + + const auto *p1 = new IR::Add(a, b); + const auto *p2 = new IR::Add(b, a); + const auto *p3 = new IR::Add(c, b); + + CHECK_LESS_FOR_P4C_EXPRESSION(p1, p2); + CHECK_GREATER_FOR_P4C_EXPRESSION(p2, p1); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(p1, p3); +} + +TEST(OperatorLess, TernaryExpressions) { + const auto *t = IR::Type_Bits::get(16, false); + const auto *a = new IR::Constant(t, 5); + const auto *b = new IR::Constant(t, 10); + const auto *c = new IR::Constant(t, 5); + + auto *trueCond = new IR::BoolLiteral(true); + auto *falseCond = new IR::BoolLiteral(false); + + const auto *p1 = new IR::Mux(trueCond, a, c); + const auto *p2 = new IR::Mux(trueCond, b, c); + const auto *p3 = new IR::Mux(trueCond, c, a); + + CHECK_LESS_FOR_P4C_EXPRESSION(p1, p2); + CHECK_GREATER_FOR_P4C_EXPRESSION(p2, p1); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(p1, p3); + + const auto *p4 = new IR::Mux(falseCond, a, c); + + CHECK_LESS_FOR_P4C_EXPRESSION(p4, p1); + CHECK_LESS_FOR_P4C_EXPRESSION(p4, p2); + CHECK_GREATER_FOR_P4C_EXPRESSION(p1, p4); + CHECK_GREATER_FOR_P4C_EXPRESSION(p2, p4); +} + +TEST(OperatorLess, PathExpressions) { + const auto *a = new IR::PathExpression("a"); + const auto *b = new IR::PathExpression("b"); + const auto *a2 = new IR::PathExpression("a"); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, a2); + + const auto *c = new IR::PathExpression(new IR::Path("c")); + auto *d = new IR::PathExpression(new IR::Path("d")); + const auto *c2 = new IR::PathExpression(new IR::Path("c")); + + CHECK_LESS_FOR_P4C_EXPRESSION(c, d); + CHECK_GREATER_FOR_P4C_EXPRESSION(d, c); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(c, c2); +} + +TEST(OperatorLess, Members) { + const auto *a = new IR::Member(new IR::PathExpression("a"), "b"); + const auto *b = new IR::Member(new IR::PathExpression("a"), "c"); + const auto *a2 = new IR::Member(new IR::PathExpression("a"), "b"); + + CHECK_LESS_FOR_P4C_EXPRESSION(a, b); + CHECK_GREATER_FOR_P4C_EXPRESSION(b, a); + CHECK_EQUALITY_FOR_P4C_EXPRESSION(a, a2); + + const auto *c = new IR::Member(new IR::PathExpression("b"), "a"); + CHECK_LESS_FOR_P4C_EXPRESSION(a, c); + CHECK_GREATER_FOR_P4C_EXPRESSION(c, a); +} + +TEST(OperatorLess, Declarations) { + const auto *t = IR::Type_Bits::get(16, false); + { + const auto *a = new IR::Declaration_Variable("a", t); + const auto *b = new IR::Declaration_Variable("b", t); + const auto *a2 = new IR::Declaration_Variable("a", t); + + CHECK_LESS_FOR_P4C_DECLARATION(a, b); + CHECK_GREATER_FOR_P4C_DECLARATION(b, a); + CHECK_EQUALITY_FOR_P4C_DECLARATION(a, a2); + } + + { + const auto *a = new IR::Declaration_Constant("a", t, new IR::Constant(t, 0)); + const auto *b = new IR::Declaration_Constant("b", t, new IR::Constant(t, 0)); + const auto *a2 = new IR::Declaration_Constant("a", t, new IR::Constant(t, 0)); + CHECK_LESS_FOR_P4C_DECLARATION(a, b); + CHECK_GREATER_FOR_P4C_DECLARATION(b, a); + CHECK_EQUALITY_FOR_P4C_DECLARATION(a, a2); + + const auto *a3 = new IR::Declaration_Constant("a", t, new IR::Constant(t, 1)); + CHECK_LESS_FOR_P4C_DECLARATION(a, a3); + CHECK_GREATER_FOR_P4C_DECLARATION(a3, a); + } +} diff --git a/tools/ir-generator/ir-generator-lex.l b/tools/ir-generator/ir-generator-lex.l index 76435760182..451d51ec998 100644 --- a/tools/ir-generator/ir-generator-lex.l +++ b/tools/ir-generator/ir-generator-lex.l @@ -86,6 +86,7 @@ static std::string comment_block; "#apply" { return APPLY; } "#no"[a-zA-Z_]* { yylval.str = cstring(yytext+3); return NO; } "#nooperator==" { yylval.str = cstring(yytext+3); return NO; } +"#nooperator<" { yylval.str = cstring(yytext+3); return NO; } "0" { yylval.str = cstring(yytext); return ZERO; } -?[0-9]+ { yylval.str = cstring(yytext); return INTEGER; } diff --git a/tools/ir-generator/irclass.cpp b/tools/ir-generator/irclass.cpp index 501c4f6774b..6507244f486 100644 --- a/tools/ir-generator/irclass.cpp +++ b/tools/ir-generator/irclass.cpp @@ -112,6 +112,7 @@ void IrDefinitions::generate(std::ostream &t, std::ostream &out, std::ostream &i << "#include \"ir/ir-inline.h\" // IWYU pragma: keep\n" << "#include \"ir/json_generator.h\" // IWYU pragma: keep\n" << "#include \"ir/json_loader.h\" // IWYU pragma: keep\n" + << "#include \"ir/semantic_less.h\" // IWYU pragma: keep\n" << "#include \"ir/visitor.h\" // IWYU pragma: keep\n" << "#include \"lib/algorithm.h\" // IWYU pragma: keep\n" << "#include \"lib/log.h\" // IWYU pragma: keep\n" @@ -538,18 +539,24 @@ Util::Enumerator *IrClass::getUserMethods() const { [](IrElement *e) { return e != nullptr; }); } +bool IrClass::hasNoDirective(cstring feature) const { + return Util::enumerate(elements) + ->where([](IrElement *el) { return el->is(); }) + ->where([feature](IrElement *el) { return el->to()->text == feature; }) + ->any(); +} + bool IrClass::shouldSkip(cstring feature) const { - // skip if there is a 'no' directive - bool explicitNo = - Util::enumerate(elements) - ->where([](IrElement *el) { return el->is(); }) - ->where([feature](IrElement *el) { return el->to()->text == feature; }) - ->any(); - if (explicitNo) return true; - // also, skip if the user provided an implementation manually - // (except for validate) - if (feature == "validate") return false; + // Validate is special, it is never skipped. + if (feature == "validate") { + return false; + } + // Skip if there is a '#no' directive. + if (hasNoDirective(feature)) { + return true; + } + // Also skip if the user provided an implementation manually bool provided = Util::enumerate(elements) ->where([feature](IrElement *e) { const auto *m = e->to(); diff --git a/tools/ir-generator/irclass.h b/tools/ir-generator/irclass.h index 214d04efa94..fee3ad29758 100644 --- a/tools/ir-generator/irclass.h +++ b/tools/ir-generator/irclass.h @@ -250,6 +250,7 @@ class IrClass : public IrElement { int generateConstructor(const ctor_args_t &args, const IrMethod *user, unsigned skip_opt); void generateMethods(); bool shouldSkip(cstring feature) const; + bool hasNoDirective(cstring feature) const; public: const IrClass *getParent() const { diff --git a/tools/ir-generator/methods.cpp b/tools/ir-generator/methods.cpp index a75ad86b1a1..c610a65d838 100644 --- a/tools/ir-generator/methods.cpp +++ b/tools/ir-generator/methods.cpp @@ -126,6 +126,61 @@ const ordered_map IrMethod::Generate = { buf << cl->indent << "}"; return buf.str(); }}}, + {"isSemanticallyLess"_cs, + {&NamedType::Bool(), + {new IrField(new ReferenceType(new NamedType(IrClass::nodeClass()), true), "a_"_cs)}, + CONST + IN_IMPL + OVERRIDE, + [](IrClass *cl, Util::SourceInfo srcInfo, cstring body) -> cstring { + std::stringstream buf; + buf << "{" << std::endl; + buf << cl->indent << cl->indent + << "if (static_cast(this) == &a_) return false;\n"; + if (auto parent = cl->getParent()) { + buf << cl->indent << cl->indent + << "if (typeId() != a_.typeId()) " + "return typeId() < a_.typeId();\n"; + if (parent->name != "Node") { + buf << cl->indent << cl->indent << "if (!" + << parent->qualified_name(cl->containedIn) << "::equiv(a_)) return " + << parent->qualified_name(cl->containedIn) << "::isSemanticallyLess(a_);\n"; + } + } + if (body) { + buf << cl->indent << cl->indent << "auto &a = static_castname + << " &>(a_);\n"; + buf << LineDirective(srcInfo, true) << body; + } else { + bool first = true; + for (auto f : *cl->getFields()) { + if (*f->type == NamedType::SourceInfo()) + continue; // FIXME -- deal with SourcInfo + if (first) { + buf << cl->indent << cl->indent << "auto &a = static_castname + << " &>(a_);\n"; + buf << cl->indent << cl->indent << "return "; + first = false; + } else { + buf << std::endl << cl->indent << cl->indent << "|| "; + } + if (f->type->resolve(cl->containedIn) == nullptr) { + // This is not an IR pointer + buf << "IR::isSemanticallyLess(" << f->name << ", a." << f->name << ")"; + } else if (f->isInline) { + buf << f->name << ".isSemanticallyLess(a." << f->name << ")"; + } else { + buf << "(" << f->name << " != nullptr ? a." << f->name << " != nullptr ? " + << f->name << "->isSemanticallyLess(*a." << f->name << ")" + << " : false : a." << f->name << " != nullptr)"; + } + } + if (first) { // no fields? + buf << cl->indent << cl->indent << "return typeId() < a_.typeId()"; + } + buf << ";" << std::endl; + } + buf << cl->indent << "}"; + return buf.str(); + }}}, {"operator<<"_cs, {&ReferenceType::OstreamRef, {new IrField(&ReferenceType::OstreamRef, "out"_cs)}, @@ -320,15 +375,20 @@ void IrClass::generateMethods() { } } } - for (auto *parent = getParent(); parent; parent = parent->getParent()) { - auto eq_overload = new IrMethod("operator=="_cs, "{ return a == *this; }"_cs); - eq_overload->clss = this; - eq_overload->isOverride = true; - eq_overload->rtype = &NamedType::Bool(); - eq_overload->args.push_back( - new IrField(new ReferenceType(new NamedType(parent), true), "a"_cs)); - eq_overload->isConst = true; - elements.push_back(eq_overload); + for (const auto *parent = getParent(); parent != nullptr; parent = parent->getParent()) { + if (!hasNoDirective("operator=="_cs)) { + std::stringstream buf; + buf << "{ return a.operator==(*this); }"; + auto *eq_overload = new IrMethod("operator=="_cs, buf.str()); + eq_overload->clss = this; + eq_overload->isOverride = true; + eq_overload->inImpl = true; + eq_overload->rtype = &NamedType::Bool(); + eq_overload->args.push_back( + new IrField(new ReferenceType(new NamedType(parent), true), "a"_cs)); + eq_overload->isConst = true; + elements.push_back(eq_overload); + } } } IrMethod *ctor = nullptr;