Skip to content

Commit

Permalink
Flatten the genEq function and support struct expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
fruffy committed Nov 13, 2023
1 parent 04eeea8 commit d71aebe
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 101 deletions.
151 changes: 69 additions & 82 deletions backends/p4tools/common/lib/gen_eq.cpp
Original file line number Diff line number Diff line change
@@ -1,112 +1,99 @@
#include "backends/p4tools/common/lib/gen_eq.h"

#include <cstddef>
#include <string>
#include <vector>

#include "ir/irutils.h"
#include "ir/vector.h"
#include "lib/cstring.h"
#include "lib/exceptions.h"

namespace P4Tools {

const IR::Expression *GenEq::equate(const IR::Expression *target, const IR::Expression *keyset) {
if (const auto *defaultKey = keyset->to<IR::DefaultExpression>()) {
return equate(target, defaultKey);
}

if (const auto *listKey = keyset->to<IR::ListExpression>()) {
return equate(target, listKey);
const IR::Expression *GenEq::checkSingleton(const IR::Expression *expr) {
if (const auto *listExpr = expr->to<IR::BaseListExpression>()) {
if (listExpr->size() == 1) {
expr = checkSingleton(listExpr->components.at(0));
}
} else if (const auto *structExpr = expr->to<IR::StructExpression>()) {
if (structExpr->size() == 1) {
expr = checkSingleton(structExpr->components.at(0)->expression);
}
}
return expr;
}

if (const auto *maskKey = keyset->to<IR::Mask>()) {
return equate(target, maskKey);
const IR::Expression *GenEq::equateListTypes(const IR::Expression *left,
const IR::Expression *right) {
std::vector<const IR::Expression *> leftElems;
if (auto listExpr = left->to<IR::BaseListExpression>()) {
leftElems = IR::flattenListExpression(listExpr);
} else if (auto structExpr = left->to<IR::StructExpression>()) {
leftElems = IR::flattenStructExpression(structExpr);
} else {
BUG("Unsupported list expression %1% of type %2%.", left, left->node_type_name());
}

if (const auto *rangeKey = keyset->to<IR::Range>()) {
return equate(target, rangeKey);
std::vector<const IR::Expression *> rightElems;
if (auto listExpr = right->to<IR::BaseListExpression>()) {
rightElems = IR::flattenListExpression(listExpr);
} else if (auto structExpr = right->to<IR::StructExpression>()) {
rightElems = IR::flattenStructExpression(structExpr);
} else {
BUG("Unsupported right list expression %1% of type %2%.", right, right->node_type_name());
}
auto leftElemsSize = leftElems.size();
auto rightElemsSize = rightElems.size();
BUG_CHECK(leftElemsSize == rightElemsSize,
"The size of left elements (%1%) and the size of right elements (%2%) are "
"different.",
leftElemsSize, rightElemsSize);

// If the target is a list expression, it had better be a singleton. In this case, recurse into
// the singleton element.
if (const auto *listTarget = target->to<IR::ListExpression>()) {
BUG_CHECK(listTarget->size() == 1, "Cannot match %1% with %2%", target, keyset);
return equate(listTarget->components.at(0), keyset);
const IR::Expression *result = new IR::BoolLiteral(IR::Type::Boolean::get(), true);
bool firstLoop = true;
for (size_t i = 0; i < leftElems.size(); i++) {
const auto *conjunct = equate(leftElems.at(i), rightElems.at(i));
if (firstLoop) {
result = conjunct;
firstLoop = false;
} else {
result = new IR::LAnd(IR::Type::Boolean::get(), result, conjunct);
}
}

return mkEq(target, keyset);
return result;
}

const IR::Expression *GenEq::equate(const IR::Expression * /*target*/,
const IR::DefaultExpression * /*keyset*/) {
return new IR::BoolLiteral(IR::Type::Boolean::get(), true);
}
const IR::Expression *GenEq::equate(const IR::Expression *left, const IR::Expression *right) {
// First, recursively unroll any singleton elements.
left = checkSingleton(left);
right = checkSingleton(right);

const IR::Expression *GenEq::equate(const IR::Expression *target,
const IR::ListExpression *keyset) {
// If the keyset is a singleton list, recurse into the singleton element.
if (keyset->size() == 1) {
return equate(target, keyset->components.at(0));
// A single default expression can be matched with a list expression.
if (left->is<IR::DefaultExpression>() || right->is<IR::DefaultExpression>()) {
return new IR::BoolLiteral(IR::Type::Boolean::get(), true);
}

const auto *listTarget = target->to<IR::ListExpression>();
BUG_CHECK(listTarget, "Cannot match %1% with %2%", target, keyset);
return equate(listTarget, keyset);
}

const IR::Expression *GenEq::equate(const IR::Expression *target, const IR::Mask *keyset) {
// If the target is a list expression, it had better be a singleton. In this case, recurse into
// the singleton element.
if (const auto *listTarget = target->to<IR::ListExpression>()) {
BUG_CHECK(listTarget->size() == 1, "Cannot match %1% with %2%", target, keyset);
return equate(listTarget->components.at(0), keyset);
// If we still have lists after unrolling, compare them.
if (left->is<IR::BaseListExpression>() || left->is<IR::StructExpression>()) {
BUG_CHECK(right->is<IR::BaseListExpression>() || right->is<IR::StructExpression>(),
"Right expression must be a list expression. Is %1% of type %2%.", right,
right->node_type_name());
return equateListTypes(left, right);
}

// Let a &&& b represent the keyset.
// We return a & b == target & b.
return mkEq(new IR::BAnd(target->type, keyset->left, keyset->right),
new IR::BAnd(target->type, target, keyset->right));
}

const IR::Expression *GenEq::equate(const IR::Expression *target, const IR::Range *keyset) {
// If the target is a list expression, it had better be a singleton. In this case, recurse into
// the singleton element.
if (const auto *listTarget = target->to<IR::ListExpression>()) {
BUG_CHECK(listTarget->size() == 1, "Cannot match %1% with %2%", target, keyset);
return equate(listTarget->components.at(0), keyset);
// At this point, all lists must be resolved.
if (const auto *maskKey = right->to<IR::Mask>()) {
// Let a &&& b represent the keyset.
// We return a & b == target & b.
return mkEq(new IR::BAnd(left->type, maskKey->left, maskKey->right),
new IR::BAnd(left->type, left, maskKey->right));
}

const auto *boolType = IR::Type::Boolean::get();
return new IR::LAnd(boolType, new IR::Leq(boolType, keyset->left, target),
new IR::Leq(boolType, target, keyset->right));
}

const IR::Expression *GenEq::equate(const IR::ListExpression *target,
const IR::ListExpression *keyset) {
// If the keyset is a singleton list, recurse into the singleton element. Similarly for the
// target.
if (keyset->size() == 1) {
return equate(target, keyset->components.at(0));
}
if (target->size() == 1) {
return equate(target->components.at(0), keyset);
if (const auto *rangeKey = right->to<IR::Range>()) {
const auto *boolType = IR::Type::Boolean::get();
return new IR::LAnd(boolType, new IR::Leq(boolType, rangeKey->left, left),
new IR::Leq(boolType, left, rangeKey->right));
}

BUG_CHECK(target->size() == keyset->size(), "Cannot match %1% with %2%", target, keyset);

const IR::Expression *result = new IR::BoolLiteral(IR::Type::Boolean::get(), true);
bool firstLoop = true;
for (size_t i = 0; i < target->size(); i++) {
const auto *conjunct = equate(target->components.at(i), keyset->components.at(i));
if (firstLoop) {
result = conjunct;
firstLoop = false;
} else {
result = new IR::LAnd(IR::Type::Boolean::get(), result, conjunct);
}
}

return result;
return mkEq(left, right);
}

const IR::Equ *GenEq::mkEq(const IR::Expression *e1, const IR::Expression *e2) {
Expand Down
24 changes: 10 additions & 14 deletions backends/p4tools/common/lib/gen_eq.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,23 @@

namespace P4Tools {

/// Generates an equality on a target expression and a keyset expression, recursing into lists.
/// Generates an equality on two input expressions, recursing into lists and structs.
/// This supports fuzzy matching on singleton lists: singleton lists are considered the same as
/// their singleton elements. This is implemented by eagerly recursing into singleton lists before
/// attempting to generate the equality.
class GenEq {
public:
static const IR::Expression *equate(const IR::Expression *target, const IR::Expression *keyset);
static const IR::Expression *equate(const IR::Expression *left, const IR::Expression *right);

private:
static const IR::Expression *equate(const IR::Expression *target,
const IR::DefaultExpression *keyset);

static const IR::Expression *equate(const IR::Expression *target,
const IR::ListExpression *keyset);

static const IR::Expression *equate(const IR::Expression *target, const IR::Mask *keyset);

static const IR::Expression *equate(const IR::Expression *target, const IR::Range *keyset);

static const IR::Expression *equate(const IR::ListExpression *target,
const IR::ListExpression *keyset);
/// Recursively resolve lists of size 1 by returning the expression contained within.
static const IR::Expression *checkSingleton(const IR::Expression *expr);

/// Flatten and compare two lists.
/// Important, this equation assumes that struct expressions have been ordered.
/// This calculation does not match the names of the struct expressions.
static const IR::Expression *equateListTypes(const IR::Expression *left,
const IR::Expression *right);

/// Convenience method for producing a typed Eq node on the given expressions.
static const IR::Equ *mkEq(const IR::Expression *e1, const IR::Expression *e2);
Expand Down
2 changes: 1 addition & 1 deletion ir/expression.def
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ class StructExpression : Expression {
validate {
components.check_null(); components.validate();
BUG_CHECK(structType == nullptr || structType->is<IR::Type_Name>() ||
structType->is<IR::Type_Specialized>(),
structType->is<IR::Type_Specialized>() || structType->is<IR::Type_StructLike>(),
"%1%: unexpected struct type", this);
}
size_t size() const { return components.size(); }
Expand Down
2 changes: 1 addition & 1 deletion ir/irutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ std::vector<const Expression *> flattenStructExpression(const StructExpression *
return exprList;
}

std::vector<const Expression *> flattenListExpression(const ListExpression *listExpr) {
std::vector<const Expression *> flattenListExpression(const BaseListExpression *listExpr) {
std::vector<const Expression *> exprList;
for (const auto *listElem : listExpr->components) {
if (const auto *subListExpr = listElem->to<ListExpression>()) {
Expand Down
6 changes: 3 additions & 3 deletions ir/irutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace IR {
class BoolLiteral;
class Constant;
class Expression;
class ListExpression;
class BaseListExpression;
class Literal;
class StructExpression;
class Type;
Expand Down Expand Up @@ -76,9 +76,9 @@ const IR::Constant *convertBoolLiteral(const IR::BoolLiteral *lit);
/// This is why we require two separate methods.
std::vector<const Expression *> flattenStructExpression(const StructExpression *structExpr);

/// Given an ListExpression, returns a flat vector of the expressions contained in that
/// Given an BaseListExpression, returns a flat vector of the expressions contained in that
/// list.
std::vector<const Expression *> flattenListExpression(const ListExpression *listExpr);
std::vector<const Expression *> flattenListExpression(const BaseListExpression *listExpr);

/* =========================================================================================
* Other helper functions
Expand Down

0 comments on commit d71aebe

Please sign in to comment.