diff --git a/.github/workflows/ci-native.yml b/.github/workflows/ci-native.yml index 9d58a77..6eb78d2 100644 --- a/.github/workflows/ci-native.yml +++ b/.github/workflows/ci-native.yml @@ -25,10 +25,10 @@ jobs: # See https://docs.bazel.build/versions/master/output_directories.html path: "~/.cache/bazel" # See https://docs.github.com/en/actions/guides/caching-dependencies-to-speed-up-workflows - key: ${{ runner.os }}-build-${{ hashFiles('**/*.bzl', '**/*.bazel') }} + key: new-${{ runner.os }}-build-${{ hashFiles('**/*.bzl', '**/*.bazel') }} restore-keys: | - ${{ runner.os }}-build- - ${{ runner.os }}- + new-${{ runner.os }}-build- + new-${{ runner.os }}- - name: Install p4c system dependencies (Flex, Bison, GMP) run: sudo apt-get update && sudo apt-get install bison flex libfl-dev libgmp-dev @@ -55,10 +55,10 @@ jobs: # See https://docs.bazel.build/versions/master/output_directories.html path: "~/.cache/bazel" # See https://docs.github.com/en/actions/guides/caching-dependencies-to-speed-up-workflows - key: ${{ runner.os }}-test-${{ hashFiles('**/*.bzl', '**/*.bazel') }} + key: new-${{ runner.os }}-test-${{ hashFiles('**/*.bzl', '**/*.bazel') }} restore-keys: | - ${{ runner.os }}-test- - ${{ runner.os }}- + new-${{ runner.os }}-test- + new-${{ runner.os }}- - name: Install p4c system dependencies (Flex, Bison, GMP) run: sudo apt-get update && sudo apt-get install bison flex libfl-dev libgmp-dev diff --git a/docs/language-specification.md b/docs/language-specification.md index 34d37f1..1989c62 100644 --- a/docs/language-specification.md +++ b/docs/language-specification.md @@ -73,11 +73,19 @@ table: | exact | k::value | bit\ | | ternary | k::value | bit\ | | | k::mask | bit\ | +| optional | k::value | bit\ | +| | k::mask | bit\ | | lpm | k::value | bit\ | | | k::prefix_length | int | | range | k::low | bit\ | | | k::high | bit\ | +Note that an `optional` match is just a restricted kind of `ternary` match whose mask always satisfies the following constraint: +``` +// Exact match or wildcard match. +optional_match_key::mask == 0 || optional_match_key::mask == -1 +``` + When `k` is of type `bool`, everything behaves precisely as if `k` was of type `bit<1>`, with the boolean constant `true` and `false` being mapped to `1` and `0`, respectively. diff --git a/e2e_tests/invalid_constraints.expected.output b/e2e_tests/invalid_constraints.expected.output index c95a545..77bf75d 100644 --- a/e2e_tests/invalid_constraints.expected.output +++ b/e2e_tests/invalid_constraints.expected.output @@ -61,3 +61,9 @@ Type error: expected type int, got bool 84 | !false -> -8 == -0b1000 || !1 | ^ Type error: expected type bool, got int + +- e2e_tests/invalid_constraints.p4:92:5-21: + | @entry_restriction(" +92 | optional_key > 10; + | ^^^^^^^^^^^^^^^^^ +Type error: operand type optional<32> does not support ordered comparison diff --git a/e2e_tests/invalid_constraints.p4 b/e2e_tests/invalid_constraints.p4 index f5ad417..9a3d20b 100644 --- a/e2e_tests/invalid_constraints.p4 +++ b/e2e_tests/invalid_constraints.p4 @@ -85,6 +85,17 @@ control invalid_constraints(inout headers_t headers, ") table boolean_negation_of_integer { actions = {} key = {} } + + @file(__FILE__) + @line(__LINE__) + @entry_restriction(" + optional_key > 10; + ") + table optional_does_not_support_ordered_comparison { + key = { headers.ipv4.dst_addr : optional @name("optional_key"); } + actions = {} + } + apply { forgot_quotes.apply(); forgot_quotes_with_srcloc.apply(); @@ -99,6 +110,7 @@ control invalid_constraints(inout headers_t headers, scalar_has_no_field.apply(); arithmetic_negation_of_boolean.apply(); boolean_negation_of_integer.apply(); + optional_does_not_support_ordered_comparison.apply(); } } diff --git a/e2e_tests/table_entries/optional_match_table_invalid_1.pb.txt b/e2e_tests/table_entries/optional_match_table_invalid_1.pb.txt new file mode 100644 index 0000000..4950c67 --- /dev/null +++ b/e2e_tests/table_entries/optional_match_table_invalid_1.pb.txt @@ -0,0 +1,8 @@ +table_id: 4 +match { + field_id: 1 # hdr.ipv4.dst_addr + optional { + # Illegal: constraint only allows wildcard match. + value: "123" + } +} diff --git a/e2e_tests/table_entries/optional_match_table_valid_1.pb.txt b/e2e_tests/table_entries/optional_match_table_valid_1.pb.txt new file mode 100644 index 0000000..d13053c --- /dev/null +++ b/e2e_tests/table_entries/optional_match_table_valid_1.pb.txt @@ -0,0 +1,2 @@ +table_id: 4 +# No matches at all is valid, since it corresponds to "don't care". diff --git a/e2e_tests/valid_constraints.expected.output b/e2e_tests/valid_constraints.expected.output index 7915a33..063d38a 100644 --- a/e2e_tests/valid_constraints.expected.output +++ b/e2e_tests/valid_constraints.expected.output @@ -2,6 +2,10 @@ e2e_tests/table_entries/accept_all_entries_1.pb.txt: constraint satisfied e2e_tests/table_entries/acl_table_1.pb.txt: constraint violated +e2e_tests/table_entries/optional_match_table_invalid_1.pb.txt: constraint violated + +e2e_tests/table_entries/optional_match_table_valid_1.pb.txt: constraint satisfied + e2e_tests/table_entries/reject_all_entries_1.pb.txt: constraint violated e2e_tests/table_entries/unknown_table_entry.pb.txt: Error: INVALID_ARGUMENT: table entry with unknown table ID 0 (full ID: 33554432 (0x02000000)) diff --git a/e2e_tests/valid_constraints.p4 b/e2e_tests/valid_constraints.p4 index e26128d..45e1e4d 100644 --- a/e2e_tests/valid_constraints.p4 +++ b/e2e_tests/valid_constraints.p4 @@ -3,15 +3,20 @@ control valid_constraints(inout headers_t hdr, inout local_metadata_t local_metadata, inout standard_metadata_t standard_metadata) { - + @file(__FILE__) + @line(__LINE__) @entry_restriction("true") @id(1) table accept_all_entries { key = {} actions = {} } + @file(__FILE__) + @line(__LINE__) @entry_restriction("false") @id(2) table reject_all_entries { key = {} actions = {} } + @file(__FILE__) + @line(__LINE__) @entry_restriction(" // Either wildcard or exact match (i.e., "optional" match). hdr.ipv4.dst_addr::mask == 0 || hdr.ipv4.dst_addr::mask == -1; @@ -38,17 +43,35 @@ control valid_constraints(inout headers_t hdr, hdr.ipv4.dst_addr : ternary; standard_metadata.ingress_port: ternary; hdr.ipv6.dst_addr : ternary; - hdr.ipv4.src_addr : ternary; + hdr.ipv4.src_addr : optional; local_metadata.dscp : ternary; local_metadata.is_ip_packet : ternary; } actions = { } } + @file(__FILE__) + @line(__LINE__) + @entry_restriction(" + // Vacuously true, just to test syntax and implicit conversions. + hdr.ipv4.dst_addr::mask == 0 || hdr.ipv4.dst_addr::mask == -1; + hdr.ipv4.dst_addr::value == 10 || hdr.ipv4.dst_addr::value != 10; + // Same as above, but using implicit conversion. + hdr.ipv4.dst_addr == 10 || hdr.ipv4.dst_addr != 10; + // A real constraint: only wildcard match is okay. + hdr.ipv4.dst_addr::mask == 0; + ") + @id(4) + table optional_match_table { + key = { hdr.ipv4.dst_addr : optional; } + actions = {} + } + apply { accept_all_entries.apply(); reject_all_entries.apply(); vrf_classifier_table.apply(); + optional_match_table.apply(); } } diff --git a/p4_constraints/ast.cc b/p4_constraints/ast.cc index 109e26a..a2d5f4c 100644 --- a/p4_constraints/ast.cc +++ b/p4_constraints/ast.cc @@ -64,9 +64,13 @@ std::string TypeName(const Type& type) { return absl::StrCat("bit<", type.lpm().bitwidth(), ">"); case Type::kRange: return absl::StrCat("range<", type.range().bitwidth(), ">"); - default: - return "???"; + case Type::kOptionalMatch: + return absl::StrCat("optional<", type.optional_match().bitwidth(), ">"); + case Type::TYPE_NOT_SET: + break; } + LOG(DFATAL) << "invalid type: " << type.DebugString(); + return "???"; } std::ostream& operator<<(std::ostream& os, const Type& type) { @@ -97,8 +101,10 @@ absl::optional TypeBitwidth(const Type& type) { return type.lpm().bitwidth(); case Type::kRange: return type.range().bitwidth(); + case Type::kOptionalMatch: + return type.optional_match().bitwidth(); default: - return {}; + return absl::nullopt; } } @@ -119,6 +125,9 @@ bool SetTypeBitwidth(Type* type, int bitwidth) { case Type::kRange: type->mutable_range()->set_bitwidth(bitwidth); return true; + case Type::kOptionalMatch: + type->mutable_optional_match()->set_bitwidth(bitwidth); + return true; default: return false; } @@ -129,35 +138,38 @@ Type TypeCaseToType(Type::TypeCase type_case) { switch (type_case) { case Type::kUnknown: type.mutable_unknown(); - break; + return type; case Type::kUnsupported: type.mutable_unsupported(); - break; + return type; case Type::kBoolean: type.mutable_boolean(); - break; + return type; case Type::kArbitraryInt: type.mutable_arbitrary_int(); - break; + return type; case Type::kFixedUnsigned: type.mutable_fixed_unsigned(); - break; + return type; case Type::kExact: type.mutable_exact(); - break; + return type; case Type::kTernary: type.mutable_ternary(); - break; + return type; case Type::kLpm: type.mutable_lpm(); - break; + return type; case Type::kRange: type.mutable_range(); + return type; + case Type::kOptionalMatch: + type.mutable_optional_match(); + return type; + case Type::TYPE_NOT_SET: break; - default: - LOG(DFATAL) << "unknown type case: " << type_case; } - DCHECK_EQ(type.type_case(), type_case); + LOG(DFATAL) << "invalid type case: " << type_case; return type; } diff --git a/p4_constraints/ast.proto b/p4_constraints/ast.proto index 46ca487..da7b08c 100644 --- a/p4_constraints/ast.proto +++ b/p4_constraints/ast.proto @@ -101,6 +101,8 @@ message Type { Ternary ternary = 7; Lpm lpm = 8; Range range = 9; + // `optional` is a reserved name, so we use `optional_match` instead. + Optional optional_match = 10; } // Before type-checking, types may be unknown. @@ -141,6 +143,11 @@ message Type { message Range { int32 bitwidth = 1; // required } + + // Optional match, aka "Optional". + message Optional { + int32 bitwidth = 1; // required + } } // Represents the location of a character relative to a source file or table. diff --git a/p4_constraints/backend/constraint_info.cc b/p4_constraints/backend/constraint_info.cc index baf38ec..e7c26d8 100644 --- a/p4_constraints/backend/constraint_info.cc +++ b/p4_constraints/backend/constraint_info.cc @@ -110,6 +110,9 @@ absl::StatusOr ParseKeyType(const MatchField& key) { case MatchField::RANGE: type.mutable_range()->set_bitwidth(key.bitwidth()); return type; + case MatchField::OPTIONAL: + type.mutable_optional_match()->set_bitwidth(key.bitwidth()); + return type; default: return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) << "match key of invalid MatchType: " diff --git a/p4_constraints/backend/interpreter.cc b/p4_constraints/backend/interpreter.cc index 84c6e6b..8e3cc92 100644 --- a/p4_constraints/backend/interpreter.cc +++ b/p4_constraints/backend/interpreter.cc @@ -64,7 +64,7 @@ std::string P4IDToString(uint32_t p4_object_id) { // -- Parsing P4RT table entries ----------------------------------------------- // See https://p4.org/p4runtime/spec/master/P4Runtime-Spec.html#sec-bytestrings. -absl::StatusOr ParseP4RTInteger(const std::string& int_str) { +static absl::StatusOr ParseP4RTInteger(const std::string& int_str) { mpz_class integer; const char* chars = int_str.c_str(); const size_t char_count = strlen(chars); @@ -78,6 +78,11 @@ absl::StatusOr ParseP4RTInteger(const std::string& int_str) { return integer; } +static Integer MaxValueForBitwidth(int bitwidth) { + // 2^bitwidth - 1 + return (mpz_class(1) << bitwidth) - mpz_class(1); +} + // Returns (table key name, table key value)-pair. absl::StatusOr> ParseKey( const p4::v1::FieldMatch& p4field, const TableInfo& table_info) { @@ -130,6 +135,18 @@ absl::StatusOr> ParseKey( return {std::make_pair(key.name, Range{.low = low, .high = high})}; } + case p4::v1::FieldMatch::kOptional: { + RET_CHECK_EQ(key.type.type_case(), Type::kOptionalMatch) + << "P4RT table entry inconsistent with P4 program"; + ASSIGN_OR_RETURN( + Integer value, ParseP4RTInteger(p4field.optional().value()), + _ << " while parsing field 'value' of optional key " << key.name); + return {std::make_pair( + key.name, Ternary{.value = value, + .mask = MaxValueForBitwidth( + key.type.optional_match().bitwidth())})}; + } + default: return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) << "unsupported P4RT field match type " @@ -139,17 +156,53 @@ absl::StatusOr> ParseKey( absl::StatusOr ParseEntry(const p4::v1::TableEntry& entry, const TableInfo& table_info) { + // Parse all keys that are explicitly present. absl::flat_hash_map keys; for (const p4::v1::FieldMatch& field : entry.match()) { - ASSIGN_OR_RETURN(auto kv, ParseKey(field, table_info), - _ << " while parsing P4RT table entry"); + ASSIGN_OR_RETURN(auto kv, ParseKey(field, table_info)); auto result = keys.insert(kv); if (result.second == false) { return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) - << "Unable to parse P4RT table entry: duplicate match on key " - << kv.first << " with ID " << P4IDToString(field.field_id()); + << "duplicate match on key " << kv.first << " with ID " + << P4IDToString(field.field_id()); + } + } + + // Use default value for omitted keys. + // See Section 9.1.1. of the P4runtime specification. + for (const auto& [name, key_info] : table_info.keys_by_name) { + if (keys.contains(name)) continue; + switch (key_info.type.type_case()) { + case ast::Type::kExact: + return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) + << "missing exact match key '" << key_info.name << "'"; + case ast::Type::kTernary: + case ast::Type::kOptionalMatch: + keys[name] = Ternary{}; + continue; + case ast::Type::kLpm: + keys[name] = Lpm{}; + continue; + case ast::Type::kRange: + keys[name] = Range{ + .low = mpz_class(0), + .high = MaxValueForBitwidth(key_info.type.range().bitwidth()), + }; + continue; + case ast::Type::kBoolean: + case ast::Type::kArbitraryInt: + case ast::Type::kFixedUnsigned: + case ast::Type::kUnknown: + case ast::Type::kUnsupported: + case ast::Type::TYPE_NOT_SET: + break; } + return gutils::InternalErrorBuilder(GUTILS_LOC) + << "Key '" << key_info.name + << "' of invalid match type detected at runtime: " + << key_info.type.DebugString(); } + return TableEntry{.table_name = table_info.name, .keys = keys}; } @@ -186,51 +239,53 @@ absl::StatusOr EvalAndCastTo(const Type& type, const Expression& expr, const TableEntry& entry) { ASSIGN_OR_RETURN(EvalResult result, Eval(expr, entry)); - if (!absl::holds_alternative(result)) - return TypeError(expr.start_location(), expr.end_location()) - << "expected expression of (or castable to) type " << type; - const Integer value = absl::get(result); - const Integer one = mpz_class(1); - const Integer zero = mpz_class(0); - const int bitwidth = TypeBitwidth(type).value_or(-1); - DCHECK_NE(bitwidth, -1) << "can only cast to fixed-size types"; - switch (type.type_case()) { - // int ~~> bit - // n |~> n mod 2^W - case Type::kFixedUnsigned: { - Integer domain_size = one << bitwidth; // 2^W - Integer fixed_value = value % domain_size; - // operator% may return negative values. - if (fixed_value < zero) fixed_value += domain_size; - return {fixed_value}; - } + if (absl::holds_alternative(result)) { + const Integer value = absl::get(result); + const Integer one = mpz_class(1); + const Integer zero = mpz_class(0); + const int bitwidth = TypeBitwidth(type).value_or(-1); + DCHECK_NE(bitwidth, -1) << "can only cast to fixed-size types"; + switch (type.type_case()) { + // int ~~> bit + // n |~> n mod 2^W + case Type::kFixedUnsigned: { + Integer domain_size = one << bitwidth; // 2^W + Integer fixed_value = value % domain_size; + // operator% may return negative values. + if (fixed_value < zero) fixed_value += domain_size; + return {fixed_value}; + } - // bit ~~> Exact - // n |~> Exact { value = n } - case Type::kExact: - return {Exact{.value = value}}; + // bit ~~> Exact + // n |~> Exact { value = n } + case Type::kExact: + return {Exact{.value = value}}; + + // bit ~~> Ternary/Optional + // n |~> Ternary { value = n; mask = 2^W-1 } + case Type::kTernary: + case Type::kOptionalMatch: { + Integer mask = (one << bitwidth) - one; // 2^W - 1 + return {Ternary{.value = value, .mask = mask}}; + } - // bit ~~> Ternary - // n |~> Ternary { value = n; mask = 2^W-1 } - case Type::kTernary: { - Integer mask = (one << bitwidth) - one; // 2^W - 1 - return {Ternary{.value = value, .mask = mask}}; - } + // bit ~~> LPM + // n |~> LPM { value = n; prefix_length = W } + case Type::kLpm: + return {Lpm{.value = value, .prefix_length = mpz_class(bitwidth)}}; - // bit ~~> LPM - // n |~> LPM { value = n; prefix_length = W } - case Type::kLpm: - return {Lpm{.value = value, .prefix_length = mpz_class(bitwidth)}}; + // bit ~~> Range + // n |~> Range { low = n; high = n } + case Type::kRange: + return {Range{.low = value, .high = value}}; - // bit ~~> Range - // n |~> Range { low = n; high = n } - case Type::kRange: - return {Range{.low = value, .high = value}}; - - default: - return gutils::InternalErrorBuilder(GUTILS_LOC) - << "don't know how to cast to type " << type; + default: + break; + } } + return TypeError(expr.start_location(), expr.end_location()) + << "cannot cast expression of type " << expr.type() << " to type " + << type; } absl::StatusOr EvalBinaryExpression(ast::BinaryOperator binop, @@ -336,12 +391,9 @@ struct EvalFieldAccess { if (field == "high") return {range.high}; return Error("range"); } + absl::StatusOr operator()(bool) { return Error("bool"); } - absl::StatusOr operator()(bool b) { return Error("bool"); } - - absl::StatusOr operator()(const Integer& i) { - return Error("int"); - } + absl::StatusOr operator()(const Integer&) { return Error("int"); } }; // -- Main evaluator ----------------------------------------------------------- @@ -367,9 +419,10 @@ absl::StatusOr Eval_(const Expression& expr, case Expression::kKey: { auto it = entry.keys.find(expr.key()); - if (it == entry.keys.end()) + if (it == entry.keys.end()) { TypeError(expr.start_location(), expr.end_location()) << "unknown key " << expr.key() << " in table " << entry.table_name; + } return it->second; } @@ -408,13 +461,10 @@ absl::StatusOr Eval_(const Expression& expr, } case Expression::EXPRESSION_NOT_SET: - return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) - << "invalid expression: " << expr.DebugString(); - - default: - return gutils::UnimplementedErrorBuilder(GUTILS_LOC) - << "unknown expression case: " << expr.expression_case(); + break; } + return gutils::InvalidArgumentErrorBuilder(GUTILS_LOC) + << "invalid expression: " << expr.DebugString(); } // -- Sanity checking ---------------------------------------------------------- @@ -424,25 +474,30 @@ absl::StatusOr Eval_(const Expression& expr, absl::Status DynamicTypeCheck(const Expression& expr, const EvalResult result) { switch (expr.type().type_case()) { case Type::kBoolean: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; case Type::kArbitraryInt: case Type::kFixedUnsigned: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; case Type::kExact: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; case Type::kTernary: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; case Type::kLpm: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; case Type::kRange: - if (absl::holds_alternative(result)) return {}; + if (absl::holds_alternative(result)) return absl::OkStatus(); break; - default: + case Type::kOptionalMatch: + if (absl::holds_alternative(result)) return absl::OkStatus(); + break; + case Type::kUnknown: + case Type::kUnsupported: + case Type::TYPE_NOT_SET: break; } return TypeError(expr.start_location(), expr.end_location()) @@ -476,7 +531,9 @@ absl::StatusOr EntryMeetsConstraint(const p4::v1::TableEntry& entry, << "table entry with unknown table ID " << P4IDToString(entry.table_id()); const TableInfo& table_info = it->second; - ASSIGN_OR_RETURN(TableEntry parsed_entry, ParseEntry(entry, table_info)); + ASSIGN_OR_RETURN(TableEntry parsed_entry, ParseEntry(entry, table_info), + _ << " while parsing P4RT table entry for table '" + << table_info.name << "':"); // Check if entry satisfies table constraint (if present). if (!table_info.constraint.has_value()) { diff --git a/p4_constraints/backend/interpreter.h b/p4_constraints/backend/interpreter.h index e541c74..cf14edb 100644 --- a/p4_constraints/backend/interpreter.h +++ b/p4_constraints/backend/interpreter.h @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "p4/v1/p4runtime.pb.h" #include "p4_constraints/ast.pb.h" #include "p4_constraints/backend/constraint_info.h" @@ -54,6 +55,8 @@ struct Exact { Integer value; }; +// Used to represent both ternary and optional keys at runtime, since an +// optional key is just a ternary key whose mask is all zeros or all ones. struct Ternary { Integer value; Integer mask; @@ -69,6 +72,11 @@ struct Range { Integer high; }; +// Evaluation can result in a value of various types. +// We use a tagged union to ease debugging (see DynamicTypeCheck); an untagged +// union would work just fine assuming the type checker has no bugs. +using EvalResult = absl::variant; + inline bool operator==(const Exact& left, const Exact& right) { return left.value == right.value; } @@ -85,14 +93,42 @@ inline bool operator==(const Range& left, const Range& right) { return left.low == right.low && left.high == right.high; } -// Evaluation can result in a value of various types. -// We use a tagged union to ease debugging (see DynamicTypeCheck); an untagged -// union would work just fine assuming the type checker has no bugs. -using EvalResult = absl::variant; +inline std::ostream& operator<<(std::ostream& os, const Integer& integer) { + return os << integer.get_str(); +} + +inline std::ostream& operator<<(std::ostream& os, const Exact& exact) { + return os << absl::StrFormat("Exact{.value = %s}", exact.value.get_str()); +} + +inline std::ostream& operator<<(std::ostream& os, const Ternary& ternary) { + return os << absl::StrFormat("Ternary{.value = %s, .mask = %s}", + ternary.value.get_str(), ternary.mask.get_str()); +} + +inline std::ostream& operator<<(std::ostream& os, const Lpm& lpm) { + return os << absl::StrFormat("Lpm{.value = %s, .prefix_length = %s}", + lpm.value.get_str(), + lpm.prefix_length.get_str()); +} + +inline std::ostream& operator<<(std::ostream& os, const Range& range) { + return os << absl::StrFormat("Range{.low = %s, .high = %s}", + range.low.get_str(), range.high.get_str()); +} + +// TODO(smolkaj): The code below does not compile with C++11. Find workaround. +// inline std::ostream& operator<<(std::ostream& os, const EvalResult& result) { +// absl::visit([&](const auto& result) { os << result; }, result); +// return os; +// } // Parsed representation of p4::v1::TableEntry. struct TableEntry { std::string table_name; + // All table keys, by name. + // In contrast to p4::v1::TableEntry, all keys must be present, i.e. this must + // be a total map from key names to values. absl::flat_hash_map keys; // TODO(smolkaj): once we support actions, they will be added here. }; diff --git a/p4_constraints/backend/interpreter_test.cc b/p4_constraints/backend/interpreter_test.cc index 8eea3fc..0f48ad3 100644 --- a/p4_constraints/backend/interpreter_test.cc +++ b/p4_constraints/backend/interpreter_test.cc @@ -58,42 +58,49 @@ class EntryMeetsConstraintTest : public ::testing::Test { const Type kTernary32 = ParseTextProtoOrDie("ternary { bitwidth: 32 }"); const Type kLpm32 = ParseTextProtoOrDie("lpm { bitwidth: 32 }"); const Type kRange32 = ParseTextProtoOrDie("range { bitwidth: 32 }"); + const Type kOptional32 = + ParseTextProtoOrDie("optional_match { bitwidth: 32 }"); const TableInfo kTableInfo{ - .id = 0, + .id = 1, .name = "table", .constraint = {}, // To be filled in later. - .keys_by_id = {}, // Not needed for testing. + .keys_by_id = + { + {1, {1, "exact32", kExact32}}, + // For testing purposes, fine to omit the other keys here. + }, .keys_by_name = { - {"unknown", {0, "unknown", kUnknown}}, - {"unsupported", {0, "unsupported", kUnsupported}}, - {"bool", {0, "bool", kBool}}, - {"int", {0, "int", kArbitraryInt}}, - {"bit16", {0, "bit16", kFixedUnsigned16}}, - {"bit32", {0, "bit32", kFixedUnsigned32}}, - {"exact32", {0, "exact32", kExact32}}, - {"ternary32", {0, "ternary32", kTernary32}}, - {"lpm32", {0, "lpm32", kLpm32}}, - {"range32", {0, "range32", kRange32}}, + {"exact32", {1, "exact32", kExact32}}, + {"ternary32", {2, "ternary32", kTernary32}}, + {"lpm32", {3, "lpm32", kLpm32}}, + {"range32", {4, "range32", kRange32}}, + {"optional32", {5, "optional32", kOptional32}}, }}; const TableEntry kParsedEntry{ .table_name = "table", .keys = { - {"unknown", {false}}, - {"unsupported", {false}}, - {"bool", {true}}, - {"int", {mpz_class("-1")}}, - {"bit16", {mpz_class("42")}}, - {"bit32", {mpz_class("200")}}, - {"exact32", {Exact{.value = mpz_class("13")}}}, + {"exact32", {Exact{.value = mpz_class(42)}}}, {"ternary32", - {Ternary{.value = mpz_class("12"), .mask = mpz_class("128")}}}, + {Ternary{.value = mpz_class(12), .mask = mpz_class(128)}}}, {"lpm32", - {Lpm{.value = mpz_class("0"), .prefix_length = mpz_class("32")}}}, - {"range32", {Range{.low = mpz_class("5"), .high = mpz_class("500")}}}, + {Lpm{.value = mpz_class(0), .prefix_length = mpz_class(32)}}}, + {"range32", {Range{.low = mpz_class(5), .high = mpz_class(500)}}}, + {"optional32", + {Ternary{.value = mpz_class(12), + .mask = (mpz_class(1) << 32) - mpz_class(1)}}}, }}; + const p4::v1::TableEntry kTableEntry = + ParseTextProtoOrDie(R"PROTO( + table_id: 1 + match { + field_id: 1 + exact { value: "1234" } + } + )PROTO"); + ConstraintInfo MakeConstraintInfo(const Expression& expr) { TableInfo table_info = kTableInfo; table_info.constraint = expr; @@ -126,33 +133,32 @@ class EvalTest : public EntryMeetsConstraintTest {}; TEST_F(EntryMeetsConstraintTest, EmptyExpressionErrors) { Expression expr; - p4::v1::TableEntry entry; - EXPECT_THAT(EntryMeetsConstraint(entry, MakeConstraintInfo(expr)), + EXPECT_THAT(EntryMeetsConstraint(kTableEntry, MakeConstraintInfo(expr)), StatusIs(StatusCode::kInvalidArgument)); } TEST_F(EntryMeetsConstraintTest, BooleanConstants) { - p4::v1::TableEntry entry; auto const_true = ExpressionWithType(kBool, "boolean_constant: true"); auto const_false = ExpressionWithType(kBool, "boolean_constant: false"); - EXPECT_THAT(EntryMeetsConstraint(entry, MakeConstraintInfo(const_true)), + EXPECT_THAT(EntryMeetsConstraint(kTableEntry, MakeConstraintInfo(const_true)), IsOkAndHolds(Eq(true))); - EXPECT_THAT(EntryMeetsConstraint(entry, MakeConstraintInfo(const_false)), - IsOkAndHolds(Eq(false))); + EXPECT_THAT( + EntryMeetsConstraint(kTableEntry, MakeConstraintInfo(const_false)), + IsOkAndHolds(Eq(false))); } TEST_F(EntryMeetsConstraintTest, NonBooleanConstraintsAreRejected) { - p4::v1::TableEntry entry; for (const Type& type : {kArbitraryInt, kFixedUnsigned16, kFixedUnsigned32}) { auto expr = ExpressionWithType(type, R"(integer_constant: "42")"); - EXPECT_THAT(EntryMeetsConstraint(entry, MakeConstraintInfo(expr)), + EXPECT_THAT(EntryMeetsConstraint(kTableEntry, MakeConstraintInfo(expr)), StatusIs(StatusCode::kInvalidArgument)); } // Expressions evaluating to non-scalar values should also be rejected. for (std::string key : {"exact32", "ternary32", "lpm32", "range32"}) { - EXPECT_THAT(EntryMeetsConstraint(entry, MakeConstraintInfo(KeyExpr(key))), - StatusIs(StatusCode::kInvalidArgument)); + EXPECT_THAT( + EntryMeetsConstraint(kTableEntry, MakeConstraintInfo(KeyExpr(key))), + StatusIs(StatusCode::kInvalidArgument)); } } diff --git a/p4_constraints/backend/type_checker.cc b/p4_constraints/backend/type_checker.cc index c833c38..7909b28 100644 --- a/p4_constraints/backend/type_checker.cc +++ b/p4_constraints/backend/type_checker.cc @@ -50,8 +50,8 @@ gutils::StatusBuilder TypeError(const SourceLocation& start, // Castability of types is given by the following Hasse diagram, where lower // types can be cast to higher types (but not vice versa): // -// exact ternary lpm range -// \_________ \ / _____/ +// exact ternary lpm range optional +// \_________ \ / _____/_________/ // bit // | // arbitrary_int @@ -66,6 +66,7 @@ bool StrictlyAboveInCastabilityOrder(const Type& left, const Type& right) { case Type::kTernary: case Type::kLpm: case Type::kRange: + case Type::kOptionalMatch: switch (right.type_case()) { case Type::kFixedUnsigned: return TypeBitwidth(left) == TypeBitwidth(right); @@ -94,7 +95,7 @@ absl::optional LeastUpperBound(const Type& left, const Type& right) { // While it is not true for partial orders in general that // LeastUpperBound(x,y) exists iff x >= y or y >= x, it is true for our // castability relation. - return {}; + return absl::nullopt; } // Mutates the input expression, wrapping it with a type_cast to the given type. @@ -108,6 +109,29 @@ void WrapWithCast(Expression* expr, Type type) { *expr = std::move(cast); } +// Mutates the input expression, wrapping it with a chain of zero or more type +// casts to convert it, possibly transitively, to the given target type. +absl::Status CastTransitivelyTo(Expression* expr, Type target_type) { + if (StrictlyAboveInCastabilityOrder(target_type, expr->type()) && + expr->type().type_case() == Type::kArbitraryInt) { + // Insert int ~~> bit cast. + Type fixed_unsigned; + auto bitwidth = TypeBitwidth(target_type).value_or(-1); + DCHECK_NE(bitwidth, -1) << "cannot cast to arbitrary-size type"; + fixed_unsigned.mutable_fixed_unsigned()->set_bitwidth(bitwidth); + WrapWithCast(expr, fixed_unsigned); + } + + if (StrictlyAboveInCastabilityOrder(target_type, expr->type()) && + expr->type().type_case() == Type::kFixedUnsigned) { + // Insert bit ~~> Exact/Ternary/LPM/Range/Optional cast. + WrapWithCast(expr, target_type); + } + + DCHECK_EQ(expr->type(), target_type) << "unification did not unify types"; + return absl::OkStatus(); +} + // Attempts to unify the types of the given expressions, returning the // resulting type if unification succeeds, or an InvalidArgument Status // otherwise. @@ -125,28 +149,10 @@ absl::StatusOr Unify(Expression* left, Expression* right) { LeastUpperBound(left->type(), right->type()); if (!least_upper_bound.has_value()) { return TypeError(left->start_location(), right->end_location()) - << "cannot unify types " << TypeName(left->type()) << " and " - << TypeName(right->type()); - } - for (Expression* expr : {left, right}) { - if (StrictlyAboveInCastabilityOrder(*least_upper_bound, expr->type()) && - expr->type().type_case() == Type::kArbitraryInt) { - // Insert int ~~> bit cast. - Type fixed_unsigned; - auto bitwidth = TypeBitwidth(*least_upper_bound).value_or(-1); - DCHECK_NE(bitwidth, -1) << "cannot cast to arbitrary-size type"; - fixed_unsigned.mutable_fixed_unsigned()->set_bitwidth(bitwidth); - WrapWithCast(expr, fixed_unsigned); - } - - if (StrictlyAboveInCastabilityOrder(*least_upper_bound, expr->type()) && - expr->type().type_case() == Type::kFixedUnsigned) { - WrapWithCast(expr, *least_upper_bound); - } - - DCHECK_EQ(*least_upper_bound, expr->type()) - << "unification did not unify types"; + << "cannot unify types " << left->type() << " and " << right->type(); } + RETURN_IF_ERROR(CastTransitivelyTo(left, *least_upper_bound)); + RETURN_IF_ERROR(CastTransitivelyTo(right, *least_upper_bound)); return *least_upper_bound; } @@ -164,6 +170,8 @@ const auto* const kFieldTypes = {std::make_tuple(Type::kLpm, "prefix_length"), Type::kArbitraryInt}, {std::make_tuple(Type::kRange, "low"), Type::kFixedUnsigned}, {std::make_tuple(Type::kRange, "high"), Type::kFixedUnsigned}, + {std::make_tuple(Type::kOptionalMatch, "value"), Type::kFixedUnsigned}, + {std::make_tuple(Type::kOptionalMatch, "mask"), Type::kFixedUnsigned}, }; absl::optional FieldTypeOfCompositeType(const Type& composite_type, @@ -175,7 +183,7 @@ absl::optional FieldTypeOfCompositeType(const Type& composite_type, absl::optional bitwidth = TypeBitwidth(composite_type); if (!bitwidth.has_value()) { LOG(DFATAL) << "expected composite type " << composite_type - << "to have bitwidth"; + << " to have bitwidth"; } SetTypeBitwidth(&field_type, bitwidth.value_or(-1)); return {field_type}; @@ -298,10 +306,11 @@ absl::Status InferAndCheckTypes(Expression* expr, const TableInfo& table_info) { return absl::OkStatus(); } - default: - return gutils::InternalErrorBuilder(GUTILS_LOC) - << "unknown expression case: " << expr->expression_case(); + case Expression::EXPRESSION_NOT_SET: + break; } + return TypeError(expr->start_location(), expr->end_location()) + << "unexpected expression: " << expr->DebugString(); } } // namespace p4_constraints diff --git a/p4_constraints/backend/type_checker_test.cc b/p4_constraints/backend/type_checker_test.cc index c5b7ef9..793718e 100644 --- a/p4_constraints/backend/type_checker_test.cc +++ b/p4_constraints/backend/type_checker_test.cc @@ -53,6 +53,8 @@ class InferAndCheckTypesTest : public ::testing::Test { const Type kTernary32 = ParseTextProtoOrDie("ternary { bitwidth: 32 }"); const Type kLpm32 = ParseTextProtoOrDie("lpm { bitwidth: 32 }"); const Type kRange32 = ParseTextProtoOrDie("range { bitwidth: 32 }"); + const Type kOptional32 = + ParseTextProtoOrDie("optional_match { bitwidth: 32 }"); const TableInfo kTableInfo{ 0, @@ -76,11 +78,11 @@ class InferAndCheckTypesTest : public ::testing::Test { TEST_F(InferAndCheckTypesTest, InvalidExpressions) { Expression expr = ParseTextProtoOrDie(""); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), - StatusIs(StatusCode::kInternal)); + StatusIs(StatusCode::kInvalidArgument)); expr = ParseTextProtoOrDie("boolean_negation {}"); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), - StatusIs(StatusCode::kInternal)); + StatusIs(StatusCode::kInvalidArgument)); expr = ParseTextProtoOrDie("type_cast {}"); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), @@ -88,7 +90,7 @@ TEST_F(InferAndCheckTypesTest, InvalidExpressions) { expr = ParseTextProtoOrDie("binary_expression {}"); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), - StatusIs(StatusCode::kInternal)); + StatusIs(StatusCode::kInvalidArgument)); } TEST_F(InferAndCheckTypesTest, BooleanConstant) { @@ -159,7 +161,7 @@ TEST_F(InferAndCheckTypesTest, BooleanNegationOfBooleansTypeChecks) { TEST_F(InferAndCheckTypesTest, BooleanNegationOfNonBooleansDoesNotTypeCheck) { Expression expr = ParseTextProtoOrDie("boolean_negation {}"); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), - StatusIs(StatusCode::kInternal)); + StatusIs(StatusCode::kInvalidArgument)); expr = ParseTextProtoOrDie(R"PROTO( boolean_negation { integer_constant: "0" } @@ -201,7 +203,7 @@ TEST_F(InferAndCheckTypesTest, ArithmeticNegationOfIntTypeChecks) { TEST_F(InferAndCheckTypesTest, ArithmeticNegationOfNonIntDoesNotTypeChecks) { Expression expr = ParseTextProtoOrDie("arithmetic_negation {}"); ASSERT_THAT(InferAndCheckTypes(&expr, kTableInfo), - StatusIs(StatusCode::kInternal)); + StatusIs(StatusCode::kInvalidArgument)); expr = ParseTextProtoOrDie(R"( arithmetic_negation { boolean_constant: true }