From ad161b26d6867db8524e8e4601f3567ea11cd185 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Wed, 17 Jul 2024 16:08:04 -0700 Subject: [PATCH] [NFC] Split TypeInference (#4814) * Factor our type type checking into a separate file. No functionality change Signed-off-by: Anton Korobeynikov * Factor out expression type checking into a separate file. No functionality change Signed-off-by: Anton Korobeynikov * Factor out type checking for statements. No functionality change. Signed-off-by: Anton Korobeynikov * Factor out declarations type checking. Some further cleanup here and there. No functionality change Signed-off-by: Anton Korobeynikov * Get rid of std::function Signed-off-by: Anton Korobeynikov * Move to proper place Signed-off-by: Anton Korobeynikov * Switch to better containers Signed-off-by: Anton Korobeynikov * Simplify includes Signed-off-by: Anton Korobeynikov * Do not forget to lint headers Signed-off-by: Anton Korobeynikov --------- Signed-off-by: Anton Korobeynikov --- frontends/CMakeLists.txt | 5 + .../typeChecking/constantTypeSubstitution.h | 83 + frontends/p4/typeChecking/typeCheckDecl.cpp | 317 ++ frontends/p4/typeChecking/typeCheckExpr.cpp | 2253 +++++++++++ frontends/p4/typeChecking/typeCheckStmt.cpp | 339 ++ frontends/p4/typeChecking/typeCheckTypes.cpp | 559 +++ frontends/p4/typeChecking/typeChecker.cpp | 3460 +---------------- frontends/p4/typeChecking/typeChecker.h | 6 +- frontends/p4/typeChecking/typeConstraints.h | 8 +- 9 files changed, 3566 insertions(+), 3464 deletions(-) create mode 100644 frontends/p4/typeChecking/constantTypeSubstitution.h create mode 100644 frontends/p4/typeChecking/typeCheckDecl.cpp create mode 100644 frontends/p4/typeChecking/typeCheckExpr.cpp create mode 100644 frontends/p4/typeChecking/typeCheckStmt.cpp create mode 100644 frontends/p4/typeChecking/typeCheckTypes.cpp diff --git a/frontends/CMakeLists.txt b/frontends/CMakeLists.txt index 892eed24d45..127edcff4b7 100644 --- a/frontends/CMakeLists.txt +++ b/frontends/CMakeLists.txt @@ -71,6 +71,10 @@ set (P4_FRONTEND_SRCS p4/typeChecking/bindVariables.cpp p4/typeChecking/syntacticEquivalence.cpp p4/typeChecking/typeChecker.cpp + p4/typeChecking/typeCheckDecl.cpp + p4/typeChecking/typeCheckExpr.cpp + p4/typeChecking/typeCheckStmt.cpp + p4/typeChecking/typeCheckTypes.cpp p4/typeChecking/typeConstraints.cpp p4/typeChecking/typeSubstitution.cpp p4/typeChecking/typeSubstitutionVisitor.cpp @@ -144,6 +148,7 @@ set (P4_FRONTEND_HDRS p4/ternaryBool.h p4/toP4/toP4.h p4/typeChecking/bindVariables.h + p4/typeChecking/constantTypeSubstitution.h p4/typeChecking/syntacticEquivalence.h p4/typeChecking/typeChecker.h p4/typeChecking/typeConstraints.h diff --git a/frontends/p4/typeChecking/constantTypeSubstitution.h b/frontends/p4/typeChecking/constantTypeSubstitution.h new file mode 100644 index 00000000000..1d18eece3b5 --- /dev/null +++ b/frontends/p4/typeChecking/constantTypeSubstitution.h @@ -0,0 +1,83 @@ +/* +Copyright 2013-present Barefoot Networks, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef FRONTENDS_P4_TYPECHECKING_CONSTANTTYPESUBSTITUTION_H_ +#define FRONTENDS_P4_TYPECHECKING_CONSTANTTYPESUBSTITUTION_H_ + +#include "frontends/common/resolveReferences/resolveReferences.h" +#include "ir/visitor.h" +#include "typeChecker.h" +#include "typeSubstitution.h" + +namespace P4 { + +// Used to set the type of Constants after type inference +class ConstantTypeSubstitution : public Transform, ResolutionContext { + TypeVariableSubstitution *subst; + TypeMap *typeMap; + TypeInference *tc; + + public: + ConstantTypeSubstitution(TypeVariableSubstitution *subst, TypeMap *typeMap, TypeInference *tc) + : subst(subst), typeMap(typeMap), tc(tc) { + CHECK_NULL(subst); + CHECK_NULL(typeMap); + CHECK_NULL(tc); + LOG3("ConstantTypeSubstitution " << subst); + } + + const IR::Node *postorder(IR::Constant *cst) override { + auto cstType = typeMap->getType(getOriginal(), true); + if (!cstType->is()) return cst; + auto repl = cstType; + while (repl->is()) { + auto next = subst->get(repl->to()); + BUG_CHECK(next != repl, "Cycle in substitutions: %1%", next); + if (!next) break; + repl = next; + } + if (repl != cstType) { + // We may replace a type variable with another one + LOG2("Inferred type " << repl << " for " << cst); + cst = new IR::Constant(cst->srcInfo, repl, cst->value, cst->base); + } else { + LOG2("No type inferred for " << cst << " repl is " << repl); + } + return cst; + } + + const IR::Expression *convert(const IR::Expression *expr, const Visitor::Context *ctxt) { + auto result = expr->apply(*this, ctxt)->to(); + if (result != expr && (::errorCount() == 0)) tc->learn(result, this, ctxt); + return result; + } + const IR::Vector *convert(const IR::Vector *vec, + const Visitor::Context *ctxt) { + auto result = vec->apply(*this, ctxt)->to>(); + if (result != vec) tc->learn(result, this, ctxt); + return result; + } + const IR::Vector *convert(const IR::Vector *vec, + const Visitor::Context *ctxt) { + auto result = vec->apply(*this, ctxt)->to>(); + if (result != vec) tc->learn(result, this, ctxt); + return result; + } +}; + +} // namespace P4 + +#endif // FRONTENDS_P4_TYPECHECKING_CONSTANTTYPESUBSTITUTION_H_ diff --git a/frontends/p4/typeChecking/typeCheckDecl.cpp b/frontends/p4/typeChecking/typeCheckDecl.cpp new file mode 100644 index 00000000000..f440d65c6fa --- /dev/null +++ b/frontends/p4/typeChecking/typeCheckDecl.cpp @@ -0,0 +1,317 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "typeChecker.h" + +namespace P4 { + +using namespace literals; + +const IR::Node *TypeInference::postorder(IR::P4Table *table) { + currentActionList = nullptr; + if (done()) return table; + auto type = new IR::Type_Table(table); + setType(getOriginal(), type); + setType(table, type); + return table; +} + +bool TypeInference::checkParameters(const IR::ParameterList *paramList, bool forbidModules, + bool forbidPackage) const { + for (auto p : paramList->parameters) { + auto type = getType(p); + if (type == nullptr) return false; + if (auto ts = type->to()) type = ts->baseType; + if (forbidPackage && type->is()) { + typeError("%1%: parameter cannot be a package", p); + return false; + } + if (p->direction != IR::Direction::None && + (type->is() || type->is())) { + typeError("%1%: a parameter with type %2% cannot have a direction", p, type); + return false; + } + if ((forbidModules || p->direction != IR::Direction::None) && + (type->is() || type->is() || + type->is() || type->is())) { + typeError("%1%: parameter cannot have type %2%", p, type); + return false; + } + } + return true; +} + +const IR::ParameterList *TypeInference::canonicalizeParameters(const IR::ParameterList *params) { + if (params == nullptr) return params; + + bool changes = false; + auto vec = new IR::IndexedVector(); + for (auto p : *params->getEnumerator()) { + auto paramType = getTypeType(p->type); + if (paramType == nullptr) return nullptr; + BUG_CHECK(!paramType->is(), "%1%: Unexpected parameter type", paramType); + if (paramType != p->type) { + p = new IR::Parameter(p->srcInfo, p->name, p->annotations, p->direction, paramType, + p->defaultValue); + changes = true; + } + setType(p, paramType); + vec->push_back(p); + } + if (changes) + return new IR::ParameterList(params->srcInfo, *vec); + else + return params; +} + +const IR::Node *TypeInference::postorder(IR::P4Action *action) { + if (done()) return action; + auto pl = canonicalizeParameters(action->parameters); + if (pl == nullptr) return action; + if (!checkParameters(action->parameters, forbidModules, forbidPackages)) return action; + auto type = new IR::Type_Action(new IR::TypeParameters(), nullptr, pl); + + bool foundDirectionless = false; + for (auto p : action->parameters->parameters) { + auto ptype = getType(p); + BUG_CHECK(ptype, "%1%: parameter type missing when it was found previously", p); + if (ptype->is()) + typeError("%1%: Action parameters cannot have extern types", p->type); + if (p->direction == IR::Direction::None) + foundDirectionless = true; + else if (foundDirectionless) + typeError("%1%: direction-less action parameters have to be at the end", p); + } + setType(getOriginal(), type); + setType(action, type); + return action; +} + +const IR::Node *TypeInference::postorder(IR::Declaration_Variable *decl) { + if (done()) return decl; + auto type = getTypeType(decl->type); + if (type == nullptr) return decl; + + if (const IR::IMayBeGenericType *gt = type->to()) { + // Check that there are no unbound type parameters + if (!gt->getTypeParameters()->empty()) { + typeError("Unspecified type parameters for %1% in %2%", gt, decl); + return decl; + } + } + + const IR::Type *baseType = type; + if (auto sc = type->to()) baseType = sc->baseType; + if (baseType->is() || baseType->is() || + baseType->is() || baseType->is()) { + typeError("%1%: cannot declare variables of type '%2%' (consider using an instantiation)", + decl, type); + return decl; + } + + if (type->is() || type->is()) { + typeError("%1%: Cannot declare variables with type %2%", decl, type); + return decl; + } + + auto orig = getOriginal(); + if (decl->initializer != nullptr) { + auto init = assignment(decl, type, decl->initializer); + if (decl->initializer != init) { + auto declType = type->getP4Type(); + decl->type = declType; + decl->initializer = init; + LOG2("Created new declaration " << decl); + } + } + setType(decl, type); + setType(orig, type); + return decl; +} + +const IR::Node *TypeInference::postorder(IR::Declaration_Constant *decl) { + if (done()) return decl; + auto type = getTypeType(decl->type); + if (type == nullptr) return decl; + + if (type->is()) { + typeError("%1%: Cannot declare constants of extern types", decl->name); + return decl; + } + + if (!isCompileTimeConstant(decl->initializer)) + typeError("%1%: Cannot evaluate initializer to a compile-time constant", decl->initializer); + auto orig = getOriginal(); + auto newInit = assignment(decl, type, decl->initializer); + if (newInit != decl->initializer) + decl = new IR::Declaration_Constant(decl->srcInfo, decl->name, decl->annotations, + decl->type, newInit); + setType(decl, type); + setType(orig, type); + return decl; +} + +// Return true on success +bool TypeInference::checkAbstractMethods(const IR::Declaration_Instance *inst, + const IR::Type_Extern *type) { + // Make a list of the abstract methods + IR::NameMap virt; + for (auto m : type->methods) + if (m->isAbstract) virt.addUnique(m->name, m); + if (virt.size() == 0 && inst->initializer == nullptr) return true; + if (virt.size() == 0 && inst->initializer != nullptr) { + typeError("%1%: instance initializers for extern without abstract methods", + inst->initializer); + return false; + } else if (virt.size() != 0 && inst->initializer == nullptr) { + typeError("%1%: must declare abstract methods for %2%", inst, type); + return false; + } + + for (auto d : inst->initializer->components) { + if (auto *func = d->to()) { + LOG2("Type checking " << func); + if (func->type->typeParameters->size() != 0) { + typeError("%1%: abstract method implementations cannot be generic", func); + return false; + } + auto ftype = getType(func); + if (virt.find(func->name.name) == virt.end()) { + typeError("%1%: no matching abstract method in %2%", func, type); + return false; + } + auto meth = virt[func->name.name]; + auto methtype = getType(meth); + virt.erase(func->name.name); + auto tvs = + unify(inst, methtype, ftype, "Method '%1%' does not have the expected type '%2%'", + {func, methtype}); + if (tvs == nullptr) return false; + BUG_CHECK(errorCount() > 0 || tvs->isIdentity(), "%1%: expected no type variables", + tvs); + } + } + bool rv = true; + for (auto &vm : virt) { + if (!vm.second->annotations->getSingle("optional"_cs)) { + typeError("%1%: %2% abstract method not implemented", inst, vm.second); + rv = false; + } + } + return rv; +} + +const IR::Node *TypeInference::preorder(IR::Declaration_Instance *decl) { + // We need to control the order of the type-checking: we want to do first + // the declaration, and then typecheck the initializer if present. + if (done()) return decl; + visit(decl->type, "type"); + visit(decl->arguments, "arguments"); + visit(decl->annotations, "annotations"); + visit(decl->properties, "properties"); + + auto type = getTypeType(decl->type); + if (type == nullptr) { + prune(); + return decl; + } + auto orig = getOriginal(); + + auto simpleType = type; + if (auto *sc = type->to()) simpleType = sc->substituted; + + if (auto et = simpleType->to()) { + auto [newType, newArgs] = checkExternConstructor(decl, et, decl->arguments); + if (newArgs == nullptr) { + prune(); + return decl; + } + // type can be Type_Extern or Type_SpecializedCanonical. If it is already + // specialized, the type arguments were specified explicitly. + // Otherwise, we use the type received from checkExternConstructor, which + // has substituted the type variables with fresh ones. + if (type->is()) type = newType; + decl->arguments = newArgs; + setType(orig, type); + setType(decl, type); + + if (decl->initializer != nullptr) visit(decl->initializer); + // This will need the decl type to be already known + bool s = checkAbstractMethods(decl, et); + if (!s) { + prune(); + return decl; + } + } else if (simpleType->is()) { + if (decl->initializer != nullptr) { + typeError("%1%: initializers only allowed for extern instances", decl->initializer); + prune(); + return decl; + } + if (!simpleType->is() && (findContext() == nullptr)) { + ::error(ErrorType::ERR_INVALID, "%1%: cannot instantiate at top-level", decl); + return decl; + } + auto typeAndArgs = + containerInstantiation(decl, decl->arguments, simpleType->to()); + auto type = typeAndArgs.first; + auto args = typeAndArgs.second; + if (type == nullptr || args == nullptr) { + prune(); + return decl; + } + learn(type, this, getChildContext()); + if (args != decl->arguments) decl->arguments = args; + setType(decl, type); + setType(orig, type); + } else { + typeError("%1%: cannot allocate objects of type %2%", decl, type); + } + prune(); + return decl; +} + +const IR::Node *TypeInference::preorder(IR::Function *function) { + if (done()) return function; + visit(function->type); + auto type = getTypeType(function->type); + if (type == nullptr) return function; + setType(getOriginal(), type); + setType(function, type); + visit(function->body); + prune(); + return function; +} + +const IR::Node *TypeInference::postorder(IR::Method *method) { + if (done()) return method; + auto type = getTypeType(method->type); + if (type == nullptr) return method; + if (auto mtype = type->to()) { + if (mtype->returnType) { + if (auto gen = mtype->returnType->to()) { + if (gen->getTypeParameters()->size() != 0) { + typeError("%1%: no type parameters supplied for return generic type", + method->type->returnType); + return method; + } + } + } + } + setType(getOriginal(), type); + setType(method, type); + return method; +} + +} // namespace P4 diff --git a/frontends/p4/typeChecking/typeCheckExpr.cpp b/frontends/p4/typeChecking/typeCheckExpr.cpp new file mode 100644 index 00000000000..f7ecbd7a34c --- /dev/null +++ b/frontends/p4/typeChecking/typeCheckExpr.cpp @@ -0,0 +1,2253 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "constantTypeSubstitution.h" +#include "frontends/p4/enumInstance.h" +#include "lib/algorithm.h" +#include "typeChecker.h" +#include "typeConstraints.h" + +namespace P4 { + +using namespace literals; + +const IR::Node *TypeInference::postorder(IR::Parameter *param) { + if (done()) return param; + const IR::Type *paramType = getTypeType(param->type); + if (paramType == nullptr) return param; + BUG_CHECK(!paramType->is(), "%1%: unexpected type", paramType); + + if (paramType->is() || paramType->is()) { + typeError("%1%: parameter cannot have type %2%", param, paramType); + return param; + } + + if (!readOnly && paramType->is()) { + // We only give these errors if we are no in 'readOnly' mode: + // this prevents giving a confusing error message to the user. + if (param->direction != IR::Direction::None) { + typeError("%1%: parameters with type %2% must be directionless", param, paramType); + return param; + } + if (findContext()) { + typeError("%1%: actions cannot have parameters with type %2%", param, paramType); + return param; + } + } + + // The parameter type cannot have free type variables + if (auto *gen = paramType->to()) { + auto tp = gen->getTypeParameters(); + if (!tp->empty()) { + typeError("Type parameters needed for %1%", param->name); + return param; + } + } + + if (param->defaultValue) { + if (!typeMap->isCompileTimeConstant(param->defaultValue)) + typeError("%1%: expression must be a compile-time constant", param->defaultValue); + } + + setType(getOriginal(), paramType); + setType(param, paramType); + return param; +} + +const IR::Node *TypeInference::postorder(IR::Constant *expression) { + if (done()) return expression; + auto type = getTypeType(expression->type); + if (type == nullptr) return expression; + setType(getOriginal(), type); + setType(expression, type); + setCompileTimeConstant(getOriginal()); + setCompileTimeConstant(expression); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::StringLiteral *expression) { + if (done()) return expression; + setType(getOriginal(), IR::Type_String::get()); + setType(expression, IR::Type_String::get()); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::BoolLiteral *expression) { + if (done()) return expression; + setType(getOriginal(), IR::Type_Boolean::get()); + setType(expression, IR::Type_Boolean::get()); + setCompileTimeConstant(getOriginal()); + setCompileTimeConstant(expression); + return expression; +} + +bool TypeInference::containsActionEnum(const IR::Type *type) const { + if (auto st = type->to()) { + if (auto field = st->getField(IR::Type_Table::action_run)) { + auto ft = getTypeType(field->type); + if (ft->is()) return true; + } + } + return false; +} + +// Returns false on error +bool TypeInference::compare(const IR::Node *errorPosition, const IR::Type *ltype, + const IR::Type *rtype, Comparison *compare) { + if (ltype->is() || rtype->is()) { + // Actions return Type_Action instead of void. + typeError("%1% and %2% cannot be compared", compare->left, compare->right); + return false; + } + if (ltype->is() || rtype->is()) { + typeError("%1% and %2%: tables cannot be compared", compare->left, compare->right); + return false; + } + if (ltype->is() || rtype->is()) { + typeError("%1% and %2%: externs cannot be compared", compare->left, compare->right); + return false; + } + if (containsActionEnum(ltype) || containsActionEnum(rtype)) { + typeError("%1% and %2%: table application results cannot be compared", compare->left, + compare->right); + return false; + } + + bool defined = false; + if (typeMap->equivalent(ltype, rtype) && + (!ltype->is() && !ltype->is()) && + !ltype->to()) { + defined = true; + } else if (ltype->is() && rtype->is() && + typeMap->equivalent(ltype, rtype)) { + defined = true; + } else if (ltype->is() && rtype->is()) { + auto tvs = unify(errorPosition, ltype, rtype); + if (tvs == nullptr) return false; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + compare->left = cts.convert(compare->left, getChildContext()); + compare->right = cts.convert(compare->right, getChildContext()); + } + defined = true; + } else if (auto se = rtype->to()) { + // This can only happen in a switch statement, other comparisons + // eliminate SerEnums before calling here. + if (typeMap->equivalent(ltype, se->type)) defined = true; + } else { + auto ls = ltype->to(); + auto rs = rtype->to(); + if (ls != nullptr || rs != nullptr) { + if (ls != nullptr && rs != nullptr) { + typeError("%1%: cannot compare structure-valued expressions with unknown types", + errorPosition); + return false; + } + + bool lcst = isCompileTimeConstant(compare->left); + bool rcst = isCompileTimeConstant(compare->right); + TypeVariableSubstitution *tvs; + if (ls == nullptr) { + tvs = unify(errorPosition, ltype, rtype); + } else { + tvs = unify(errorPosition, rtype, ltype); + } + if (tvs == nullptr) return false; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + compare->left = cts.convert(compare->left, getChildContext()); + compare->right = cts.convert(compare->right, getChildContext()); + } + + if (ls != nullptr) { + auto l = compare->left->to(); + CHECK_NULL(l); // struct initializers are the only expressions that can + // have StructUnknown types + BUG_CHECK(rtype->is(), "%1%: expected a struct", rtype); + auto type = new IR::Type_Name(rtype->to()->name); + compare->left = + new IR::StructExpression(compare->left->srcInfo, type, type, l->components); + setType(compare->left, rtype); + if (lcst) setCompileTimeConstant(compare->left); + } else { + auto r = compare->right->to(); + CHECK_NULL(r); // struct initializers are the only expressions that can + // have StructUnknown types + BUG_CHECK(ltype->is(), "%1%: expected a struct", ltype); + auto type = new IR::Type_Name(ltype->to()->name); + compare->right = + new IR::StructExpression(compare->right->srcInfo, type, type, r->components); + setType(compare->right, rtype); + if (rcst) setCompileTimeConstant(compare->right); + } + defined = true; + } + + // comparison between structs and list expressions is allowed + if ((ltype->is() && rtype->is()) || + (ltype->is() && rtype->is())) { + if (!ltype->is()) { + // swap + auto type = ltype; + ltype = rtype; + rtype = type; + } + + auto tvs = unify(errorPosition, ltype, rtype); + if (tvs == nullptr) return false; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + compare->left = cts.convert(compare->left, getChildContext()); + compare->right = cts.convert(compare->right, getChildContext()); + } + defined = true; + } + } + + if (!defined) { + typeError("'%1%' with type '%2%' cannot be compared to '%3%' with type '%4%'", + compare->left, ltype, compare->right, rtype); + return false; + } + return true; +} + +const IR::Node *TypeInference::postorder(IR::Operation_Relation *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + bool equTest = expression->is() || expression->is(); + if (auto l = ltype->to()) ltype = getTypeType(l->type); + if (auto r = rtype->to()) rtype = getTypeType(r->type); + BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); + + if (ltype->is() && rtype->is()) { + // This can happen because we are replacing some constant functions with + // constants during type checking + auto result = constantFold(expression); + setType(getOriginal(), IR::Type_Boolean::get()); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } else if (ltype->is() && rtype->is()) { + auto e = expression->clone(); + e->left = new IR::Cast(e->left->srcInfo, rtype, e->left); + setType(e->left, rtype); + ltype = rtype; + expression = e; + } else if (rtype->is() && ltype->is()) { + auto e = expression->clone(); + e->right = new IR::Cast(e->right->srcInfo, ltype, e->right); + setType(e->right, ltype); + rtype = ltype; + expression = e; + } + + if (equTest) { + Comparison c; + c.left = expression->left; + c.right = expression->right; + auto b = compare(expression, ltype, rtype, &c); + if (!b) return expression; + expression->left = c.left; + expression->right = c.right; + } else { + if (!ltype->is() || !rtype->is() || !(ltype->equiv(*rtype))) { + typeError("%1%: not defined on %2% and %3%", expression, ltype->toString(), + rtype->toString()); + return expression; + } + } + setType(getOriginal(), IR::Type_Boolean::get()); + setType(expression, IR::Type_Boolean::get()); + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Concat *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + if (ltype->is()) { + typeError("Please specify a width for the operand %1% of a concatenation", + expression->left); + return expression; + } + if (rtype->is()) { + typeError("Please specify a width for the operand %1% of a concatenation", + expression->right); + return expression; + } + + bool castLeft = false; + bool castRight = false; + if (auto se = ltype->to()) { + ltype = getTypeType(se->type); + castLeft = true; + } + if (auto se = rtype->to()) { + rtype = getTypeType(se->type); + castRight = true; + } + if (ltype == nullptr || rtype == nullptr) { + // getTypeType should have already taken care of the error message + return expression; + } + if (!ltype->is() || !rtype->is()) { + typeError("%1%: Concatenation not defined on %2% and %3%", expression, ltype->toString(), + rtype->toString()); + return expression; + } + auto bl = ltype->to(); + auto br = rtype->to(); + const IR::Type *resultType = IR::Type_Bits::get(bl->size + br->size, bl->isSigned); + + if (castLeft) { + auto e = expression->clone(); + e->left = new IR::Cast(e->left->srcInfo, bl, e->left); + if (isCompileTimeConstant(expression->left)) setCompileTimeConstant(e->left); + setType(e->left, ltype); + expression = e; + } + if (castRight) { + auto e = expression->clone(); + e->right = new IR::Cast(e->right->srcInfo, br, e->right); + if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(e->right); + setType(e->right, rtype); + expression = e; + } + + resultType = canonicalize(resultType); + if (resultType != nullptr) { + setType(getOriginal(), resultType); + setType(expression, resultType); + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + } + return expression; +} + +/** + * compute the type of table keys. + * Used to typecheck pre-defined entries. + */ +const IR::Node *TypeInference::postorder(IR::Key *key) { + // compute the type and store it in typeMap + auto keyTuple = new IR::Type_Tuple; + for (auto ke : key->keyElements) { + auto kt = typeMap->getType(ke->expression); + if (kt == nullptr) { + LOG2("Bailing out for " << dbp(ke)); + return key; + } + keyTuple->components.push_back(kt); + } + LOG2("Setting key type to " << dbp(keyTuple)); + setType(key, keyTuple); + setType(getOriginal(), keyTuple); + return key; +} + +/** + * typecheck a table initializer entry list + */ +const IR::Node *TypeInference::preorder(IR::EntriesList *el) { + if (done()) return el; + auto table = findContext(); + BUG_CHECK(table != nullptr, "%1% entries not within a table", el); + const IR::Key *key = table->getKey(); + if (key == nullptr) { + if (el->size() != 0) + typeError("Entries cannot be specified for a table with no key %1%", table); + prune(); + return el; + } + auto keyTuple = typeMap->getType(key); // direct typeMap call to skip checks + if (keyTuple == nullptr) { + // The keys have to be before the entries list. If they are not, + // at this point they have not yet been type-checked. + if (key->srcInfo.isValid() && el->srcInfo.isValid() && key->srcInfo >= el->srcInfo) { + typeError("%1%: Entries list must be after table key %2%", el, key); + prune(); + return el; + } + // otherwise the type-checking of the keys must have failed + } + return el; +} + +/** + * typecheck a table initializer entry + * + * The invariants are: + * - table keys and entry keys must have the same length + * - entry key elements must be compile time constants + * - actionRefs in entries must be in the action list + * - table keys must have been type checked before entries + * + * Moreover, the EntriesList visitor should have checked for the table + * invariants. + */ +const IR::Node *TypeInference::postorder(IR::Entry *entry) { + if (done()) return entry; + auto table = findContext(); + if (table == nullptr) return entry; + const IR::Key *key = table->getKey(); + if (key == nullptr) return entry; + auto keyTuple = getType(key); + if (keyTuple == nullptr) return entry; + + auto entryKeyType = getType(entry->keys); + if (entryKeyType == nullptr) return entry; + if (auto ts = entryKeyType->to()) entryKeyType = ts->elementType; + if (entry->singleton) { + if (auto tl = entryKeyType->to()) { + // An entry of _ does not have type Tuple, but rather Type_Dontcare + if (tl->getSize() == 1 && tl->components.at(0)->is()) + entryKeyType = tl->components.at(0); + } + } + + auto keyset = entry->getKeys(); + if (keyset == nullptr || !(keyset->is())) { + typeError("%1%: key expression must be tuple", keyset); + return entry; + } + if (keyset->components.size() < key->keyElements.size()) { + typeError("%1%: Size of entry keyset must match the table key set size", keyset); + return entry; + } + + bool nonConstantKeys = false; + for (auto ke : keyset->components) + if (!isCompileTimeConstant(ke)) { + typeError("Key entry must be a compile time constant: %1%", ke); + nonConstantKeys = true; + } + if (nonConstantKeys) return entry; + + if (entry->priority && !isCompileTimeConstant(entry->priority)) { + typeError("Entry priority must be a compile time constant: %1%", entry->priority); + return entry; + } + + TypeVariableSubstitution *tvs = + unifyCast(entry, keyTuple, entryKeyType, + "Table entry has type '%1%' which is not the expected type '%2%'", + {keyTuple, entryKeyType}); + if (tvs == nullptr) return entry; + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto ks = cts.convert(keyset, getChildContext()); + if (::errorCount() > 0) return entry; + + if (ks != keyset) + entry = new IR::Entry(entry->srcInfo, entry->annotations, entry->isConst, entry->priority, + ks->to(), entry->action, entry->singleton); + + auto actionRef = entry->getAction(); + auto ale = validateActionInitializer(actionRef); + if (ale != nullptr) { + auto anno = ale->getAnnotation(IR::Annotation::defaultOnlyAnnotation); + if (anno != nullptr) { + typeError("%1%: Action marked with %2% used in table", entry, + IR::Annotation::defaultOnlyAnnotation); + return entry; + } + } + return entry; +} + +const IR::Node *TypeInference::postorder(IR::ListExpression *expression) { + if (done()) return expression; + bool constant = true; + auto components = new IR::Vector(); + for (auto c : expression->components) { + if (!isCompileTimeConstant(c)) constant = false; + auto type = getType(c); + if (type == nullptr) return expression; + components->push_back(type); + } + + auto tupleType = new IR::Type_List(expression->srcInfo, *components); + auto type = canonicalize(tupleType); + if (type == nullptr) return expression; + setType(getOriginal(), type); + setType(expression, type); + if (constant) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Invalid *expression) { + if (done()) return expression; + auto unk = IR::Type_Unknown::get(); + setType(expression, unk); + setType(getOriginal(), unk); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::InvalidHeader *expression) { + if (done()) return expression; + auto type = getTypeType(expression->headerType); + auto concreteType = type; + if (auto ts = concreteType->to()) concreteType = ts->substituted; + if (!concreteType->is()) { + typeError("%1%: invalid header expression has a non-header type `%2%`", expression, type); + return expression; + } + setType(expression, type); + setType(getOriginal(), type); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::InvalidHeaderUnion *expression) { + if (done()) return expression; + auto type = getTypeType(expression->headerUnionType); + auto concreteType = type; + if (auto ts = concreteType->to()) concreteType = ts->substituted; + if (!concreteType->is()) { + typeError("%1%: does not have a header_union type `%2%`", expression, type); + return expression; + } + setType(expression, type); + setType(getOriginal(), type); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::P4ListExpression *expression) { + if (done()) return expression; + bool constant = true; + auto elementType = getTypeType(expression->elementType); + auto vec = new IR::Vector(); + bool changed = false; + for (auto c : expression->components) { + if (!isCompileTimeConstant(c)) constant = false; + auto type = getType(c); + if (type == nullptr) return expression; + auto tvs = unify(expression, elementType, type, + "Vector element type '%1%' does not match expected type '%2%'", + {type, elementType}); + if (tvs == nullptr) return expression; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto converted = cts.convert(c, getChildContext()); + vec->push_back(converted); + changed = changed || converted != c; + } else { + vec->push_back(c); + } + } + + if (changed) + expression = new IR::P4ListExpression(expression->srcInfo, *vec, elementType->getP4Type()); + auto type = new IR::Type_P4List(expression->srcInfo, elementType); + setType(getOriginal(), type); + setType(expression, type); + if (constant) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::HeaderStackExpression *expression) { + if (done()) return expression; + bool constant = true; + auto stackType = getTypeType(expression->headerStackType); + if (auto st = stackType->to()) { + auto elementType = st->elementType; + auto vec = new IR::Vector(); + bool changed = false; + if (expression->size() != st->getSize()) { + typeError("%1%: number of initializers %2% has to match stack size %3%", expression, + expression->size(), st->getSize()); + return expression; + } + for (auto c : expression->components) { + if (!isCompileTimeConstant(c)) constant = false; + auto type = getType(c); + if (type == nullptr) return expression; + auto tvs = unify(expression, elementType, type, + "Stack element type '%1%' does not match expected type '%2%'", + {type, elementType}); + if (tvs == nullptr) return expression; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto converted = cts.convert(c, getChildContext()); + vec->push_back(converted); + changed = true; + } else { + vec->push_back(c); + } + if (changed) + expression = new IR::HeaderStackExpression(expression->srcInfo, *vec, stackType); + } + } else { + typeError("%1%: header stack expression has an incorrect type `%2%`", expression, + stackType); + return expression; + } + + setType(getOriginal(), stackType); + setType(expression, stackType); + if (constant) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::StructExpression *expression) { + if (done()) return expression; + bool constant = true; + auto components = new IR::IndexedVector(); + for (auto c : expression->components) { + if (!isCompileTimeConstant(c->expression)) constant = false; + auto type = getType(c->expression); + if (type == nullptr) return expression; + components->push_back(new IR::StructField(c->name, type)); + } + + // This is the type inferred by looking at the fields. + const IR::Type *structType = + new IR::Type_UnknownStruct(expression->srcInfo, "unknown struct", *components); + structType = canonicalize(structType); + + const IR::Expression *result = expression; + if (expression->structType != nullptr) { + // We know the exact type of the initializer + auto desired = getTypeType(expression->structType); + if (desired == nullptr) return expression; + auto tvs = unify(expression, desired, structType, + "Initializer type '%1%' does not match expected type '%2%'", + {structType, desired}); + if (tvs == nullptr) return expression; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + result = cts.convert(expression, getChildContext()); + } + structType = desired; + } + setType(getOriginal(), structType); + setType(expression, structType); + if (constant) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return result; +} + +const IR::Node *TypeInference::postorder(IR::ArrayIndex *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + auto hst = ltype->to(); + + int index = -1; + if (auto cst = expression->right->to()) { + if (hst && checkArrays && !cst->fitsInt()) { + typeError("Index too large: %1%", cst); + return expression; + } + index = cst->asInt(); + if (hst && checkArrays && index < 0) { + typeError("%1%: Negative array index %2%", expression, cst); + return expression; + } + } + // if index is negative here it means it's not a constant + + if ((index < 0) && !rtype->is() && !rtype->is() && + !rtype->is()) { + typeError("Array index %1% must be an integer, but it has type %2%", expression->right, + rtype->toString()); + return expression; + } + + const IR::Type *type = nullptr; + if (hst) { + if (checkArrays && hst->sizeKnown()) { + int size = hst->getSize(); + if (index >= 0 && index >= size) { + typeError("Array index %1% larger or equal to array size %2%", expression->right, + hst->size); + return expression; + } + } + type = hst->elementType; + } else if (auto tup = ltype->to()) { + if (index < 0) { + typeError("Tuple index %1% must be constant", expression->right); + return expression; + } + if (static_cast(index) >= tup->getSize()) { + typeError("Tuple index %1% larger than tuple size %2%", expression->right, + tup->getSize()); + return expression; + } + type = tup->components.at(index); + if (isCompileTimeConstant(expression->left)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + } else { + typeError("Indexing %1% applied to non-array and non-tuple type %2%", expression, + ltype->toString()); + return expression; + } + if (isLeftValue(expression->left)) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } + setType(getOriginal(), type); + setType(expression, type); + return expression; +} + +const IR::Node *TypeInference::binaryBool(const IR::Operation_Binary *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + if (!ltype->is() || !rtype->is()) { + typeError("%1%: not defined on %2% and %3%", expression, ltype->toString(), + rtype->toString()); + return expression; + } + setType(getOriginal(), IR::Type_Boolean::get()); + setType(expression, IR::Type_Boolean::get()); + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::binaryArith(const IR::Operation_Binary *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + bool castLeft = false; + bool castRight = false; + + if (auto se = ltype->to()) { + ltype = getTypeType(se->type); + castLeft = true; + } + if (auto se = rtype->to()) { + rtype = getTypeType(se->type); + castRight = true; + } + BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); + + const IR::Type_Bits *bl = ltype->to(); + const IR::Type_Bits *br = rtype->to(); + if (bl == nullptr && !ltype->is()) { + typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", + expression->getStringOp(), expression->left, ltype->toString()); + return expression; + } else if (br == nullptr && !rtype->is()) { + typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", + expression->getStringOp(), expression->right, rtype->toString()); + return expression; + } else if (ltype->is() && rtype->is()) { + auto t = IR::Type_InfInt::get(); + auto result = constantFold(expression); + setType(getOriginal(), t); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + + const IR::Type *resultType = ltype; + if (bl != nullptr && br != nullptr) { + if (bl->size != br->size) { + typeError("%1%: Cannot operate on values with different widths %2% and %3%", expression, + bl->size, br->size); + return expression; + } + if (bl->isSigned != br->isSigned) { + typeError("%1%: Cannot operate on values with different signs", expression); + return expression; + } + } + if ((bl == nullptr && br != nullptr) || castLeft) { + // must insert cast on the left + auto leftResultType = br; + if (castLeft && !br) leftResultType = bl; + auto e = expression->clone(); + e->left = new IR::Cast(e->left->srcInfo, leftResultType, e->left); + setType(e->left, leftResultType); + if (isCompileTimeConstant(expression->left)) { + e->left = constantFold(e->left); + setCompileTimeConstant(e->left); + } + expression = e; + resultType = leftResultType; + } + if ((bl != nullptr && br == nullptr) || castRight) { + auto e = expression->clone(); + auto rightResultType = bl; + if (castRight && !bl) rightResultType = br; + e->right = new IR::Cast(e->right->srcInfo, rightResultType, e->right); + setType(e->right, rightResultType); + if (isCompileTimeConstant(expression->right)) { + e->right = constantFold(e->right); + setCompileTimeConstant(e->right); + } + expression = e; + resultType = rightResultType; + } + + setType(getOriginal(), resultType); + setType(expression, resultType); + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::unsBinaryArith(const IR::Operation_Binary *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + if (auto se = ltype->to()) ltype = getTypeType(se->type); + if (auto se = rtype->to()) rtype = getTypeType(se->type); + + const IR::Type_Bits *bl = ltype->to(); + if (bl != nullptr && bl->isSigned) { + typeError("%1%: Cannot operate on signed values", expression); + return expression; + } + const IR::Type_Bits *br = rtype->to(); + if (br != nullptr && br->isSigned) { + typeError("%1%: Cannot operate on signed values", expression); + return expression; + } + + auto cleft = expression->left->to(); + if (cleft != nullptr) { + if (cleft->value < 0) { + typeError("%1%: not defined on negative numbers", expression); + return expression; + } + } + auto cright = expression->right->to(); + if (cright != nullptr) { + if (cright->value < 0) { + typeError("%1%: not defined on negative numbers", expression); + return expression; + } + } + + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return binaryArith(expression); +} + +const IR::Node *TypeInference::shift(const IR::Operation_Binary *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + if (auto se = ltype->to()) ltype = getTypeType(se->type); + if (ltype == nullptr) { + // getTypeType should have already taken care of the error message + return expression; + } + auto lt = ltype->to(); + if (auto cst = expression->right->to()) { + if (!cst->fitsInt()) { + typeError("Shift amount too large: %1%", cst); + return expression; + } + int shift = cst->asInt(); + if (shift < 0) { + typeError("%1%: Negative shift amount %2%", expression, cst); + return expression; + } + if (lt != nullptr && shift >= lt->size) + warn(ErrorType::WARN_OVERFLOW, "%1%: shifting value with %2% bits by %3%", expression, + lt->size, shift); + // If the amount is signed but positive, make it unsigned + if (auto bt = rtype->to()) { + if (bt->isSigned) { + rtype = IR::Type_Bits::get(rtype->srcInfo, bt->width_bits(), false); + auto amt = new IR::Constant(cst->srcInfo, rtype, cst->value, cst->base); + if (expression->is()) { + expression = new IR::Shl(expression->srcInfo, expression->left, amt); + } else { + expression = new IR::Shr(expression->srcInfo, expression->left, amt); + } + setCompileTimeConstant(expression->right); + setType(expression->right, rtype); + } + } + } + + if (rtype->is() && rtype->to()->isSigned) { + typeError("%1%: Shift amount must be an unsigned number", expression->right); + return expression; + } + + if (!lt && !ltype->is()) { + typeError("%1% left operand of shift must be a numeric type, not %2%", expression, + ltype->toString()); + return expression; + } + + if (ltype->is() && !rtype->is() && + !isCompileTimeConstant(expression->right)) { + typeError( + "%1%: shift result type is arbitrary-precision int, but right operand is not constant; " + "width of left operand of shift needs to be specified or both operands need to be " + "constant", + expression); + return expression; + } + + setType(expression, ltype); + setType(getOriginal(), ltype); + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +// Handle .. and &&& +const IR::Node *TypeInference::typeSet(const IR::Operation_Binary *expression) { + if (done()) return expression; + auto ltype = getType(expression->left); + auto rtype = getType(expression->right); + if (ltype == nullptr || rtype == nullptr) return expression; + + auto leftType = ltype; // save original type + if (auto se = ltype->to()) ltype = getTypeType(se->type); + if (auto se = rtype->to()) rtype = getTypeType(se->type); + BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); + + // The following section is very similar to "binaryArith()" above + const IR::Type_Bits *bl = ltype->to(); + const IR::Type_Bits *br = rtype->to(); + if (bl == nullptr && !ltype->is()) { + typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", + expression->getStringOp(), expression->left, ltype->toString()); + return expression; + } else if (br == nullptr && !rtype->is()) { + typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", + expression->getStringOp(), expression->right, rtype->toString()); + return expression; + } + + const IR::Type *sameType = leftType; + if (bl != nullptr && br != nullptr) { + if (!typeMap->equivalent(bl, br)) { + typeError("%1%: Cannot operate on values with different types %2% and %3%", expression, + bl->toString(), br->toString()); + return expression; + } + } else if (bl == nullptr && br != nullptr) { + auto e = expression->clone(); + e->left = new IR::Cast(e->left->srcInfo, rtype, e->left); + setCompileTimeConstant(e->left); + expression = e; + sameType = rtype; + setType(e->left, sameType); + } else if (bl != nullptr && br == nullptr) { + auto e = expression->clone(); + e->right = new IR::Cast(e->right->srcInfo, ltype, e->right); + setCompileTimeConstant(e->right); + expression = e; + setType(e->right, ltype); + sameType = leftType; // Not ltype: SerEnum &&& Bit is Set + } else { + // both are InfInt: use same exact type for both sides, so it is properly + // set after unification + // FIXME -- the below is obviously wrong and just serves to tweak when precisely + // the type will be inferred -- papering over bugs elsewhere in typechecking, + // avoiding the BUG_CHECK(!readOnly... in end_apply/apply_visitor above. + // (maybe just need learner->readOnly = false in TypeInference::learn above?) + auto r = expression->right->clone(); + auto e = expression->clone(); + if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(r); + e->right = r; + expression = e; + setType(r, sameType); + } + + auto resultType = new IR::Type_Set(sameType->srcInfo, sameType); + typeMap->setType(expression, resultType); + typeMap->setType(getOriginal(), resultType); + + if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + + return expression; +} + +const IR::Node *TypeInference::postorder(IR::LNot *expression) { + if (done()) return expression; + auto type = getType(expression->expr); + if (type == nullptr) return expression; + if (!(*type == *IR::Type_Boolean::get())) { + typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), + expression->expr, type->toString()); + } else { + setType(expression, IR::Type_Boolean::get()); + setType(getOriginal(), IR::Type_Boolean::get()); + } + if (isCompileTimeConstant(expression->expr)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Neg *expression) { + if (done()) return expression; + auto type = getType(expression->expr); + if (type == nullptr) return expression; + + if (auto se = type->to()) type = getTypeType(se->type); + BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); + + if (type->is()) { + setType(getOriginal(), type); + setType(expression, type); + } else if (type->is()) { + setType(getOriginal(), type); + setType(expression, type); + } else { + typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), + expression->expr, type->toString()); + } + if (isCompileTimeConstant(expression->expr)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::UPlus *expression) { + if (done()) return expression; + auto type = getType(expression->expr); + if (type == nullptr) return expression; + + if (auto se = type->to()) type = getTypeType(se->type); + BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); + + if (type->is()) { + setType(getOriginal(), type); + setType(expression, type); + } else if (type->is()) { + setType(getOriginal(), type); + setType(expression, type); + } else { + typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), + expression->expr, type->toString()); + } + if (isCompileTimeConstant(expression->expr)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Cmpl *expression) { + if (done()) return expression; + auto type = getType(expression->expr); + if (type == nullptr) return expression; + + if (auto se = type->to()) type = getTypeType(se->type); + BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); + + if (type->is()) { + typeError("'%1%' cannot be applied to an operand with an unknown width"); + } else if (type->is()) { + setType(getOriginal(), type); + setType(expression, type); + } else { + typeError("Cannot apply operation '%1%' to expression '%2%' with type '%3%'", + expression->getStringOp(), expression->expr, type->toString()); + } + if (isCompileTimeConstant(expression->expr)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Cast *expression) { + if (done()) return expression; + const IR::Type *sourceType = getType(expression->expr); + const IR::Type *castType = getTypeType(expression->destType); + if (sourceType == nullptr || castType == nullptr) return expression; + + auto concreteType = castType; + if (auto tsc = castType->to()) concreteType = tsc->substituted; + if (auto st = concreteType->to()) { + if (auto se = expression->expr->to()) { + // Interpret (S) { kvpairs } as a struct initializer expression + // instead of a cast to a struct. + if (se->type == nullptr || se->type->is() || + se->type->is()) { + auto type = castType->getP4Type(); + setType(type, new IR::Type_Type(st)); + auto sie = new IR::StructExpression(se->srcInfo, type, se->components); + auto result = postorder(sie); // may insert casts + setType(result, st); + if (isCompileTimeConstant(se)) { + setCompileTimeConstant(result->to()); + setCompileTimeConstant(getOriginal()); + } + return result; + } else { + typeError("%1%: cast not supported", expression->destType); + return expression; + } + } else if (expression->expr->is()) { + auto result = assignment(expression, st, expression->expr); + return result; + } else if (auto ih = expression->expr->to()) { + auto type = castType->getP4Type(); + auto concreteCastType = castType; + if (auto ts = castType->to()) + concreteCastType = ts->substituted; + if (concreteCastType->is()) { + setType(type, new IR::Type_Type(castType)); + auto result = new IR::InvalidHeader(ih->srcInfo, type, type); + setType(result, castType); + return result; + } else if (concreteCastType->is()) { + setType(type, new IR::Type_Type(castType)); + auto result = new IR::InvalidHeaderUnion(ih->srcInfo, type, type); + setType(result, castType); + return result; + } else { + typeError("%1%: 'invalid' expression type `%2%` must be a header or header union", + expression, castType); + return expression; + } + } + } + if (auto lt = concreteType->to()) { + auto listElementType = lt->elementType; + if (auto le = expression->expr->to()) { + IR::Vector vec; + bool isConstant = true; + for (size_t i = 0; i < le->size(); i++) { + auto compI = le->components.at(i); + auto src = assignment(expression, listElementType, compI); + if (!isCompileTimeConstant(src)) isConstant = false; + vec.push_back(src); + } + auto vecType = castType->getP4Type(); + setType(vecType, new IR::Type_Type(lt)); + auto result = new IR::P4ListExpression(le->srcInfo, vec, listElementType->getP4Type()); + setType(result, lt); + if (isConstant) { + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + } + return result; + } else { + typeError("%1%: casts to list not supported", expression); + return expression; + } + } + if (concreteType->is()) { + if (expression->expr->is()) { + auto result = assignment(expression, concreteType, expression->expr); + return result; + } else { + typeError("%1%: casts to header stack not supported", expression); + return expression; + } + } + + if (!castType->is() && !castType->is() && + !castType->is() && !castType->is() && + !castType->is() && !castType->is()) { + typeError("%1%: cast not supported", expression->destType); + return expression; + } + + if (!canCastBetween(castType, sourceType)) { + // This cast is not legal directly, but let's try to see whether + // performing a substitution can help. This will allow the use + // of constants on the RHS. + const IR::Type *destType = castType; + while (destType->is()) + destType = getTypeType(destType->to()->type); + + auto tvs = unify(expression, destType, sourceType, "Cannot cast from '%1%' to '%2%'", + {sourceType, castType}); + if (tvs == nullptr) return expression; + const IR::Expression *rhs = expression->expr; + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + rhs = cts.convert(expression->expr, getChildContext()); // sets type + } + if (rhs != expression->expr) { + // if we are here we have performed a substitution on the rhs + expression = new IR::Cast(expression->srcInfo, expression->destType, rhs); + sourceType = getTypeType(expression->destType); + } + if (!canCastBetween(castType, sourceType)) + typeError("%1%: Illegal cast from %2% to %3%", expression, sourceType->toString(), + castType->toString()); + } + setType(expression, castType); + setType(getOriginal(), castType); + if (isCompileTimeConstant(expression->expr)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::PathExpression *expression) { + if (done()) return expression; + auto decl = getDeclaration(expression->path, !errorOnNullDecls); + if (errorOnNullDecls && decl == nullptr) { + typeError("%1%: Cannot resolve declaration", expression); + return expression; + } + const IR::Type *type = nullptr; + if (auto tbl = decl->to()) { + if (auto current = findContext()) { + if (current->name == tbl->name) { + typeError("%1%: Cannot refer to the containing table %2%", expression, tbl); + return expression; + } + } + } else if (decl->is() || decl->is()) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + // For MatchKind and Errors all ids have a type that has been set + // while processing Type_Error or Declaration_Matchkind + auto declType = typeMap->getType(decl->getNode()); + if (decl->is() && declType && + (declType->is() || declType->is())) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + + if (decl->is()) { + type = IR::Type_State::get(); + } else if (decl->is()) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } else if (decl->is()) { + auto paramDecl = decl->to(); + if (paramDecl->direction == IR::Direction::InOut || + paramDecl->direction == IR::Direction::Out) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } else if (paramDecl->direction == IR::Direction::None) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + } else if (decl->is() || decl->is()) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } else if (decl->is() || decl->is()) { + type = getType(decl->getNode()); + // Each method invocation uses fresh type variables + if (type != nullptr) + // may be nullptr because typechecking may have failed + type = cloneWithFreshTypeVariables(type->to()); + } else if (decl->is()) { + typeError("%1%: Type cannot be used here, expecting an expression.", expression); + return expression; + } + + if (type == nullptr) { + type = getType(decl->getNode()); + if (type == nullptr) return expression; + } + + setType(getOriginal(), type); + setType(expression, type); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Slice *expression) { + if (done()) return expression; + const IR::Type *type = getType(expression->e0); + if (type == nullptr) return expression; + + if (auto se = type->to()) type = getTypeType(se->type); + + if (!type->is()) { + typeError("%1%: bit extraction only defined for bit<> types", expression); + return expression; + } + + auto e1type = getType(expression->e1); + if (e1type && e1type->is()) { + auto ei = EnumInstance::resolve(expression->e1, typeMap); + CHECK_NULL(ei); + auto sei = ei->to(); + if (sei == nullptr) { + typeError("%1%: slice bit index values must be constants", expression->e1); + return expression; + } + expression->e1 = sei->value; + } + auto e2type = getType(expression->e2); + if (e2type && e2type->is()) { + auto ei = EnumInstance::resolve(expression->e2, typeMap); + CHECK_NULL(ei); + auto sei = ei->to(); + if (sei == nullptr) { + typeError("%1%: slice bit index values must be constants", expression->e2); + return expression; + } + expression->e2 = sei->value; + } + + auto bst = type->to(); + if (!expression->e1->is()) { + typeError("%1%: slice bit index values must be constants", expression->e1); + return expression; + } + if (!expression->e2->is()) { + typeError("%1%: slice bit index values must be constants", expression->e2); + return expression; + } + + auto msb = expression->e1->checkedTo(); + auto lsb = expression->e2->checkedTo(); + if (!msb->fitsInt()) { + typeError("%1%: bit index too large", msb); + return expression; + } + if (!lsb->fitsInt()) { + typeError("%1%: bit index too large", lsb); + return expression; + } + int m = msb->asInt(); + int l = lsb->asInt(); + if (m < 0) { + typeError("%1%: negative bit index %2%", expression, msb); + return expression; + } + if (l < 0) { + typeError("%1%: negative bit index %2%", expression, lsb); + return expression; + } + if (m >= bst->size) { + typeError("Bit index %1% greater than width %2%", msb, bst->size); + return expression; + } + if (l >= bst->size) { + typeError("Bit index %1% greater than width %2%", msb, bst->size); + return expression; + } + if (l > m) { + typeError("LSB index %1% greater than MSB index %2%", lsb, msb); + return expression; + } + + const IR::Type *resultType = IR::Type_Bits::get(bst->srcInfo, m - l + 1, false); + resultType = canonicalize(resultType); + if (resultType == nullptr) return expression; + setType(getOriginal(), resultType); + setType(expression, resultType); + if (isLeftValue(expression->e0)) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } + if (isCompileTimeConstant(expression->e0)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Dots *expression) { + if (done()) return expression; + setType(expression, IR::Type_Any::get()); + setType(getOriginal(), IR::Type_Any::get()); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Mux *expression) { + if (done()) return expression; + const IR::Type *firstType = getType(expression->e0); + const IR::Type *secondType = getType(expression->e1); + const IR::Type *thirdType = getType(expression->e2); + if (firstType == nullptr || secondType == nullptr || thirdType == nullptr) return expression; + + if (!firstType->is()) { + typeError("Selector of %1% must be bool, not %2%", expression->getStringOp(), + firstType->toString()); + return expression; + } + + if (secondType->is() && thirdType->is()) { + typeError("Width must be specified for at least one of %1% or %2%", expression->e1, + expression->e2); + return expression; + } + auto tvs = unify(expression, secondType, thirdType, + "The expressions in a ?: conditional have different types '%1%' and '%2%'", + {secondType, thirdType}); + if (tvs != nullptr) { + if (!tvs->isIdentity()) { + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto e1 = cts.convert(expression->e1, getChildContext()); + auto e2 = cts.convert(expression->e2, getChildContext()); + if (::errorCount() > 0) return expression; + expression->e1 = e1; + expression->e2 = e2; + secondType = typeMap->getType(e1); + } + setType(expression, secondType); + setType(getOriginal(), secondType); + if (isCompileTimeConstant(expression->e0) && isCompileTimeConstant(expression->e1) && + isCompileTimeConstant(expression->e2)) { + auto result = constantFold(expression); + setCompileTimeConstant(result); + setCompileTimeConstant(getOriginal()); + return result; + } + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::TypeNameExpression *expression) { + if (done()) return expression; + auto type = getType(expression->typeName); + if (type == nullptr) return expression; + setType(getOriginal(), type); + setType(expression, type); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::Member *expression) { + if (done()) return expression; + auto type = getType(expression->expr); + if (type == nullptr) return expression; + + cstring member = expression->member.name; + if (auto ts = type->to()) type = ts->substituted; + + if (auto *ext = type->to()) { + auto call = findContext(); + if (call == nullptr) { + typeError("%1%: Methods can only be called", expression); + return expression; + } + auto method = ext->lookupMethod(expression->member, call->arguments); + if (method == nullptr) { + typeError("%1%: extern %2% does not have method matching this call", expression, + ext->name); + return expression; + } + + const IR::Type *methodType = getType(method); + if (methodType == nullptr) return expression; + // Each method invocation uses fresh type variables + methodType = cloneWithFreshTypeVariables(methodType->to()); + + setType(getOriginal(), methodType); + setType(expression, methodType); + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; + } + + bool inMethod = getParent() != nullptr; + // Built-in methods + if (inMethod && (member == IR::Type::minSizeInBits || member == IR::Type::minSizeInBytes || + member == IR::Type::maxSizeInBits || member == IR::Type::maxSizeInBytes)) { + auto type = new IR::Type_Method(IR::Type_InfInt::get(), new IR::ParameterList(), member); + auto ctype = canonicalize(type); + if (ctype == nullptr) return expression; + setType(getOriginal(), ctype); + setType(expression, ctype); + return expression; + } + + if (type->is()) { + std::string typeStr = "structure "; + if (type->is() || type->is()) { + typeStr = ""; + if (inMethod && (member == IR::Type_Header::isValid)) { + // Built-in method + auto type = + new IR::Type_Method(IR::Type_Boolean::get(), new IR::ParameterList(), member); + auto ctype = canonicalize(type); + if (ctype == nullptr) return expression; + setType(getOriginal(), ctype); + setType(expression, ctype); + return expression; + } + } + if (type->is()) { + if (inMethod && + (member == IR::Type_Header::setValid || member == IR::Type_Header::setInvalid)) { + if (!isLeftValue(expression->expr)) + typeError("%1%: must be applied to a left-value", expression); + // Built-in method + auto type = + new IR::Type_Method(IR::Type_Void::get(), new IR::ParameterList, member); + auto ctype = canonicalize(type); + if (ctype == nullptr) return expression; + setType(getOriginal(), ctype); + setType(expression, ctype); + return expression; + } + } + + auto stb = type->to(); + auto field = stb->getField(member); + if (field == nullptr) { + typeError("Field %1% is not a member of %2%%3%", expression->member, typeStr, stb); + return expression; + } + + auto fieldType = getTypeType(field->type); + if (fieldType == nullptr) return expression; + if (fieldType->is() && !getParent()) { + typeError("%1%: only allowed in switch statements", expression); + return expression; + } + setType(getOriginal(), fieldType); + setType(expression, fieldType); + if (isLeftValue(expression->expr)) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } else { + LOG2("No left value " << expression->expr); + } + if (isCompileTimeConstant(expression->expr)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + return expression; + } + + if (auto *apply = type->to(); apply && member == IR::IApply::applyMethodName) { + auto *methodType = apply->getApplyMethodType(); + auto *canon = canonicalize(methodType); + if (!canon) return expression; + methodType = canon->to(); + if (methodType == nullptr) return expression; + learn(methodType, this, getChildContext()); + setType(getOriginal(), methodType); + setType(expression, methodType); + return expression; + } + + if (auto *stack = type->to()) { + auto parser = findContext(); + if (member == IR::Type_Stack::next || member == IR::Type_Stack::last) { + if (parser == nullptr) { + typeError("%1%: 'last', and 'next' for stacks can only be used in a parser", + expression); + return expression; + } + setType(getOriginal(), stack->elementType); + setType(expression, stack->elementType); + if (isLeftValue(expression->expr) && member == IR::Type_Stack::next) { + setLeftValue(expression); + setLeftValue(getOriginal()); + } + return expression; + } else if (member == IR::Type_Stack::arraySize) { + setType(getOriginal(), IR::Type_Bits::get(32)); + setType(expression, IR::Type_Bits::get(32)); + return expression; + } else if (member == IR::Type_Stack::lastIndex) { + if (parser == nullptr) { + typeError("%1%: 'lastIndex' for stacks can only be used in a parser", expression); + return expression; + } + setType(getOriginal(), IR::Type_Bits::get(32, false)); + setType(expression, IR::Type_Bits::get(32, false)); + return expression; + } else if (member == IR::Type_Stack::push_front || member == IR::Type_Stack::pop_front) { + if (parser != nullptr) + typeError("%1%: '%2%' and '%3%' for stacks cannot be used in a parser", expression, + IR::Type_Stack::push_front, IR::Type_Stack::pop_front); + if (!isLeftValue(expression->expr)) + typeError("%1%: must be applied to a left-value", expression); + auto params = new IR::IndexedVector(); + auto param = new IR::Parameter(IR::ID("count"_cs, nullptr), IR::Direction::None, + IR::Type_InfInt::get()); + auto tt = new IR::Type_Type(param->type); + setType(param->type, tt); + setType(param, param->type); + params->push_back(param); + auto type = + new IR::Type_Method(IR::Type_Void::get(), new IR::ParameterList(*params), member); + auto canon = canonicalize(type); + if (canon == nullptr) return expression; + setType(getOriginal(), canon); + setType(expression, canon); + return expression; + } + } + + if (auto *tt = type->to()) { + auto base = tt->type; + if (base->is() || base->is() || + base->is()) { + if (isCompileTimeConstant(expression->expr)) { + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + auto fbase = base->to(); + if (auto decl = fbase->getDeclByName(member)) { + if (auto ftype = getType(decl->getNode())) { + setType(getOriginal(), ftype); + setType(expression, ftype); + } + } else { + typeError("%1%: Invalid enum tag", expression); + setType(getOriginal(), type); + setType(expression, type); + } + return expression; + } + } + + typeError("Cannot extract field %1% from %2% which has type %3%", expression->member, + expression->expr, type->toString()); + // unreachable + return expression; +} + +// If inActionList this call is made in the "action" property of a table +const IR::Expression *TypeInference::actionCall(bool inActionList, + const IR::MethodCallExpression *actionCall) { + // If a is an action with signature _(arg1, arg2, arg3) + // Then the call a(arg1, arg2) is also an + // action, with signature _(arg3) + LOG2("Processing action " << dbp(actionCall)); + + if (findContext()) { + typeError("%1%: Action calls are not allowed within parsers", actionCall); + return actionCall; + } + auto method = actionCall->method; + auto methodType = getType(method); + if (!methodType) return actionCall; // error emitted in getType + auto baseType = methodType->to(); + if (!baseType) { + typeError("%1%: must be an action", method); + return actionCall; + } + LOG2("Action type " << baseType); + BUG_CHECK(method->is(), "%1%: unexpected call", method); + BUG_CHECK(baseType->returnType == nullptr, "%1%: action with return type?", + baseType->returnType); + if (!baseType->typeParameters->empty()) { + typeError("%1%: Actions cannot be generic", baseType->typeParameters); + return actionCall; + } + if (!actionCall->typeArguments->empty()) { + typeError("%1%: Cannot supply type parameters for an action invocation", + actionCall->typeArguments); + return actionCall; + } + + bool inTable = findContext() != nullptr; + + TypeConstraints constraints(typeMap->getSubstitutions(), typeMap); + auto params = new IR::ParameterList; + + // keep track of parameters that have not been matched yet + absl::flat_hash_map left; + for (auto p : baseType->parameters->parameters) left.emplace(p->name, p); + + auto paramIt = baseType->parameters->parameters.begin(); + auto newArgs = new IR::Vector(); + bool changed = false; + for (auto arg : *actionCall->arguments) { + cstring argName = arg->name.name; + bool named = !argName.isNullOrEmpty(); + const IR::Parameter *param; + auto newExpr = arg->expression; + + if (named) { + param = baseType->parameters->getParameter(argName); + if (param == nullptr) { + typeError("%1%: No parameter named %2%", baseType->parameters, arg->name); + return actionCall; + } + } else { + if (paramIt == baseType->parameters->parameters.end()) { + typeError("%1%: Too many arguments for action", actionCall); + return actionCall; + } + param = *paramIt; + } + + LOG2("Action parameter " << dbp(param)); + if (!left.erase(param->name)) { + // This should have been checked by the CheckNamedArgs pass. + BUG("%1%: Duplicate argument name?", param->name); + } + + auto paramType = getType(param); + auto argType = getType(arg); + if (paramType == nullptr || argType == nullptr) + // type checking failed before + return actionCall; + constraints.addImplicitCastConstraint(actionCall, paramType, argType); + if (param->direction == IR::Direction::None) { + if (inActionList) { + typeError("%1%: parameter %2% cannot be bound: it is set by the control plane", arg, + param); + } else if (inTable) { + // For actions None parameters are treated as IN + // parameters when the action is called directly. We + // don't require them to be bound to a compile-time + // constant. But if the action is instantiated in a + // table (as default_action or entries), then the + // arguments do have to be compile-time constants. + if (!isCompileTimeConstant(arg->expression)) + typeError("%1%: action argument must be a compile-time constant", + arg->expression); + } + // This is like an assignment; may make additional conversions. + newExpr = assignment(arg, param->type, arg->expression); + if (readOnly) { + // FIXME -- if we're in readonly mode, we should not have introduced any mods + // here, but there's a bug in the DPDK backend where it generates a ListExpression + // that would be converted to a StructExpression, and other problems where it + // can't deal with that StructExpressions, so we hack to avoid breaking those tests + newExpr = arg->expression; + } + } else if (param->direction == IR::Direction::Out || + param->direction == IR::Direction::InOut) { + if (!isLeftValue(arg->expression)) + typeError("%1%: must be a left-value", arg->expression); + } else { + // This is like an assignment; may make additional conversions. + newExpr = assignment(arg, param->type, arg->expression); + } + if (::errorCount() > 0) return actionCall; + if (newExpr != arg->expression) { + LOG2("Changing action argument to " << newExpr); + changed = true; + newArgs->push_back(new IR::Argument(arg->srcInfo, arg->name, newExpr)); + } else { + newArgs->push_back(arg); + } + if (!named) ++paramIt; + } + if (changed) + actionCall = + new IR::MethodCallExpression(actionCall->srcInfo, actionCall->type, actionCall->method, + actionCall->typeArguments, newArgs); + + // Check remaining parameters: they must be all non-directional + bool error = false; + for (auto p : left) { + if (p.second->direction != IR::Direction::None && p.second->defaultValue == nullptr) { + typeError("%1%: Parameter %2% must be bound", actionCall, p.second); + error = true; + } + } + if (error) return actionCall; + + auto resultType = new IR::Type_Action(baseType->srcInfo, baseType->typeParameters, params); + + setType(getOriginal(), resultType); + setType(actionCall, resultType); + auto tvs = constraints.solve(); + if (tvs == nullptr || errorCount() > 0) return actionCall; + addSubstitutions(tvs); + + ConstantTypeSubstitution cts(tvs, typeMap, this); + actionCall = cts.convert(actionCall, getChildContext()) + ->to(); // cast arguments + if (::errorCount() > 0) return actionCall; + + LOG2("Converted action " << actionCall); + setType(actionCall, resultType); + return actionCall; +} + +const IR::Node *TypeInference::postorder(IR::MethodCallStatement *mcs) { + // Remove mcs if child methodCall resolves to a compile-time constant. + return !mcs->methodCall ? nullptr : mcs; +} + +const IR::Node *TypeInference::postorder(IR::MethodCallExpression *expression) { + if (done()) return expression; + LOG2("Solving method call " << dbp(expression)); + auto methodType = getType(expression->method); + if (methodType == nullptr) return expression; + auto methodBaseType = methodType->to(); + if (methodBaseType == nullptr) { + typeError("%1% is not a method", expression); + return expression; + } + + // Handle differently methods and actions: action invocations return actions + // with different signatures + if (methodType->is()) { + if (findContext()) { + typeError("%1%: Functions cannot call actions", expression); + return expression; + } + bool inActionsList = false; + auto prop = findContext(); + if (prop != nullptr && prop->name == IR::TableProperties::actionsPropertyName) + inActionsList = true; + return actionCall(inActionsList, expression); + } else { + // Constant-fold constant expressions + if (auto mem = expression->method->to()) { + auto type = typeMap->getType(mem->expr, true); + if (((mem->member == IR::Type::minSizeInBits || + mem->member == IR::Type::minSizeInBytes || + mem->member == IR::Type::maxSizeInBits || + mem->member == IR::Type::maxSizeInBytes)) && + !type->is() && expression->typeArguments->size() == 0 && + expression->arguments->size() == 0) { + auto max = mem->member.name.startsWith("max"); + int w = typeMap->widthBits(type, expression, max); + LOG3("Folding " << mem << " to " << w); + if (w < 0) return expression; + if (mem->member.name.endsWith("Bytes")) w = ROUNDUP(w, 8); + if (getParent()) return nullptr; + auto result = new IR::Constant(expression->srcInfo, w); + auto tt = new IR::Type_Type(result->type); + setType(result->type, tt); + setType(result, result->type); + setCompileTimeConstant(result); + return result; + } + if (mem->member == IR::Type_Header::isValid && type->is()) { + const IR::BoolLiteral *lit = nullptr; + if (mem->expr->is()) + lit = new IR::BoolLiteral(expression->srcInfo, false); + if (mem->expr->is()) + lit = new IR::BoolLiteral(expression->srcInfo, false); + if (mem->expr->is()) + lit = new IR::BoolLiteral(expression->srcInfo, true); + if (lit) { + LOG3("Folding " << mem << " to " << lit); + if (getParent()) return nullptr; + setType(lit, IR::Type_Boolean::get()); + setCompileTimeConstant(lit); + return lit; + } + } + } + + if (getContext()->node->is()) { + typeError("%1% is not invoking an action", expression); + return expression; + } + + // We build a type for the callExpression and unify it with the method expression + // Allocate a fresh variable for the return type; it will be hopefully bound in the process. + auto rettype = new IR::Type_Var(IR::ID(nameGen->newName("R"), ""_cs)); + auto args = new IR::Vector(); + bool constArgs = true; + for (auto aarg : *expression->arguments) { + auto arg = aarg->expression; + auto argType = getType(arg); + if (argType == nullptr) return expression; + auto argInfo = new IR::ArgumentInfo(arg->srcInfo, isLeftValue(arg), + isCompileTimeConstant(arg), argType, aarg); + args->push_back(argInfo); + constArgs = constArgs && isCompileTimeConstant(arg); + } + auto typeArgs = new IR::Vector(); + for (auto ta : *expression->typeArguments) { + auto taType = getTypeType(ta); + if (taType == nullptr) return expression; + typeArgs->push_back(taType); + } + auto callType = new IR::Type_MethodCall(expression->srcInfo, typeArgs, rettype, args); + + auto tvs = unify(expression, methodBaseType, callType, + "Function type '%1%' does not match invocation type '%2%'", + {methodBaseType, callType}); + if (tvs == nullptr) return expression; + + // Infer Dont_Care for type vars used only in not-present optional params + auto dontCares = new TypeVariableSubstitution(); + auto typeParams = methodBaseType->typeParameters; + for (auto p : *methodBaseType->parameters) { + if (!p->isOptional()) continue; + forAllMatching( + p, [tvs, dontCares, typeParams, this](const IR::Type_Var *tv) { + if (typeMap->getSubstitutions()->lookup(tv) != nullptr) + return; // already bound + if (tvs->lookup(tv)) return; // already bound + if (typeParams->getDeclByName(tv->name) != tv) return; // not a tv of this call + dontCares->setBinding(tv, IR::Type_Dontcare::get()); + }); + } + addSubstitutions(dontCares); + + LOG2("Method type before specialization " << methodType << " with " << tvs); + TypeVariableSubstitutionVisitor substVisitor(tvs); + substVisitor.setCalledBy(this); + auto specMethodType = methodType->apply(substVisitor); + LOG2("Method type after specialization " << specMethodType); + learn(specMethodType, this, getChildContext()); + + auto canon = getType(specMethodType); + if (canon == nullptr) return expression; + + auto functionType = specMethodType->to(); + BUG_CHECK(functionType != nullptr, "Method type is %1%", specMethodType); + + if (!functionType->is()) + BUG("Unexpected type for function %1%", functionType); + + auto returnType = tvs->lookup(rettype); + if (returnType == nullptr) { + typeError("Cannot infer a concrete return type for this call of %1%", expression); + return expression; + } + // The return type may also contain type variables + returnType = returnType->apply(substVisitor)->to(); + learn(returnType, this, getChildContext()); + if (returnType->is() || returnType->is() || + returnType->is() || returnType->is() || + returnType->is() || + (returnType->is() && !constArgs)) { + // Experimental: methods with all constant arguments can return an extern + // instance as a factory method evaluated at compile time. + typeError("%1%: illegal return type %2%", expression, returnType); + return expression; + } + + setType(getOriginal(), returnType); + setType(expression, returnType); + + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto result = expression; + // Arguments may need to be cast, e.g., list expression to a + // header type. + auto paramIt = functionType->parameters->begin(); + auto newArgs = new IR::Vector(); + bool changed = false; + for (auto arg : *expression->arguments) { + cstring argName = arg->name.name; + bool named = !argName.isNullOrEmpty(); + const IR::Parameter *param; + + if (named) { + param = functionType->parameters->getParameter(argName); + } else { + param = *paramIt; + } + + if (param->type->is()) + typeError( + "%1%: Could not infer a type for parameter %2% " + "(inferred type is don't care '_')", + arg, param); + + // By calling generic functions with don't care parameters + // we can force parameters to have illegal types. Check here for this case. + // e.g., void f(in T arg); table t { }; f(t); + if (param->type->is() || param->type->is() || + param->type->is() || param->type->is() || + param->type->is()) + typeError("%1%: argument cannot have type %2%", arg, param->type); + + auto newExpr = arg->expression; + if (param->direction == IR::Direction::In) { + // This is like an assignment; may make additional conversions. + newExpr = assignment(arg, param->type, arg->expression); + } else { + // Insert casts for 'int' values. + newExpr = cts.convert(newExpr, getChildContext())->to(); + } + if (::errorCount() > 0) return expression; + if (newExpr != arg->expression) { + LOG2("Changing method argument to " << newExpr); + changed = true; + newArgs->push_back(new IR::Argument(arg->srcInfo, arg->name, newExpr)); + } else { + newArgs->push_back(arg); + } + if (!named) ++paramIt; + } + + if (changed) + result = new IR::MethodCallExpression(result->srcInfo, result->type, result->method, + result->typeArguments, newArgs); + setType(result, returnType); + + auto mi = MethodInstance::resolve(result, this, typeMap, getChildContext(), true); + if (mi->isApply() && findContext()) { + typeError("%1%: apply cannot be called from actions", expression); + return expression; + } + + if (const auto *ef = mi->to()) { + const IR::Type *baseReturnType = returnType; + if (const auto *sc = returnType->to()) + baseReturnType = sc->baseType; + const bool factoryOrStaticAssert = + baseReturnType->is() || ef->method->name == "static_assert"; + if (constArgs && factoryOrStaticAssert) { + // factory extern function calls (those that return extern objects) with constant + // args are compile-time constants. + // The result of a static_assert call is also a compile-time constant. + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + } + } + + auto bi = mi->to(); + if ((findContext()) && (!bi || (bi->name == IR::Type_Stack::pop_front || + bi->name == IR::Type_Stack::push_front))) { + typeError("%1%: no function calls allowed in this context", expression); + return expression; + } + return result; + } + return expression; +} + +const IR::Node *TypeInference::postorder(IR::ConstructorCallExpression *expression) { + if (done()) return expression; + auto type = getTypeType(expression->constructedType); + if (type == nullptr) return expression; + + auto simpleType = type; + CHECK_NULL(simpleType); + if (auto *sc = type->to()) simpleType = sc->substituted; + + if (auto *e = simpleType->to()) { + auto [contType, newArgs] = checkExternConstructor(expression, e, expression->arguments); + if (newArgs == nullptr) return expression; + expression->arguments = newArgs; + setType(getOriginal(), contType); + setType(expression, contType); + } else if (auto *c = simpleType->to()) { + auto typeAndArgs = containerInstantiation(expression, expression->arguments, c); + auto contType = typeAndArgs.first; + auto args = typeAndArgs.second; + if (contType == nullptr || args == nullptr) return expression; + if (auto *st = type->to()) { + contType = new IR::Type_SpecializedCanonical(type->srcInfo, st->baseType, st->arguments, + contType); + } + expression->arguments = args; + setType(expression, contType); + setType(getOriginal(), contType); + } else { + typeError("%1%: Cannot invoke a constructor on type %2%", expression, type->toString()); + } + + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +static void convertStructToTuple(const IR::Type_StructLike *structType, IR::Type_Tuple *tuple) { + for (auto field : structType->fields) { + if (auto ft = field->type->to()) { + tuple->components.push_back(ft); + } else if (auto ft = field->type->to()) { + convertStructToTuple(ft, tuple); + } else if (auto ft = field->type->to()) { + tuple->components.push_back(ft); + } else if (auto ft = field->type->to()) { + tuple->components.push_back(ft); + } else { + typeError("Type not supported %1% for struct field %2% in 'select'", field->type, + field); + } + } +} + +const IR::SelectCase *TypeInference::matchCase(const IR::SelectExpression *select, + const IR::Type_BaseList *selectType, + const IR::SelectCase *selectCase, + const IR::Type *caseType) { + // The selectType is always a tuple + // If the caseType is a set type, we unify the type of the set elements + if (auto *set = caseType->to()) caseType = set->elementType; + // The caseType may be a simple type, and then we have to unwrap the selectType + if (caseType->is()) return selectCase; + + if (auto *sl = caseType->to()) { + auto tupleType = new IR::Type_Tuple(); + convertStructToTuple(sl, tupleType); + caseType = tupleType; + } + const IR::Type *useSelType = selectType; + if (!caseType->is()) { + if (selectType->components.size() != 1) { + typeError("Type mismatch %1% (%2%) vs %3% (%4%)", select->select, + selectType->toString(), selectCase, caseType->toString()); + return nullptr; + } + useSelType = selectType->components.at(0); + } + auto tvs = unifyCast( + select, useSelType, caseType, + "'match' case label '%1%' has type '%2%' which does not match the expected type '%3%'", + {selectCase->keyset, caseType, useSelType}); + if (tvs == nullptr) return nullptr; + ConstantTypeSubstitution cts(tvs, typeMap, this); + auto ks = cts.convert(selectCase->keyset, getChildContext()); + if (::errorCount() > 0) return selectCase; + + if (ks != selectCase->keyset) + selectCase = new IR::SelectCase(selectCase->srcInfo, ks, selectCase->state); + return selectCase; +} + +const IR::Node *TypeInference::postorder(IR::This *expression) { + if (done()) return expression; + auto decl = findContext(); + if (findContext() == nullptr || decl == nullptr) + typeError("%1%: can only be used in the definition of an abstract method", expression); + auto type = getType(decl); + setType(expression, type); + setType(getOriginal(), type); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::DefaultExpression *expression) { + if (!done()) { + setType(expression, IR::Type_Dontcare::get()); + setType(getOriginal(), IR::Type_Dontcare::get()); + } + setCompileTimeConstant(expression); + setCompileTimeConstant(getOriginal()); + return expression; +} + +bool TypeInference::containsHeader(const IR::Type *type) { + if (type->is() || type->is() || + type->is()) + return true; + if (auto *st = type->to()) { + for (auto f : st->fields) + if (containsHeader(f->type)) return true; + } + return false; +} + +/// Expressions that appear in a select expression are restricted to a small +/// number of types: bits, enums, serializable enums, and booleans. +static bool validateSelectTypes(const IR::Type *type, const IR::SelectExpression *expression) { + if (auto tuple = type->to()) { + for (auto ct : tuple->components) { + auto check = validateSelectTypes(ct, expression); + if (!check) return false; + } + return true; + } else if (type->is() || type->is() || + type->is() || type->is()) { + return true; + } + typeError("Expression '%1%' with a component of type '%2%' cannot be used in 'select'", + expression->select, type); + return false; +} + +const IR::Node *TypeInference::postorder(IR::SelectExpression *expression) { + if (done()) return expression; + auto selectType = getType(expression->select); + if (selectType == nullptr) return expression; + + // Check that the selectType is determined + auto tuple = selectType->to(); + BUG_CHECK(tuple != nullptr, "%1%: Expected a tuple type for the select expression, got %2%", + expression, selectType); + if (!validateSelectTypes(selectType, expression)) return expression; + + bool changes = false; + IR::Vector vec; + for (auto sc : expression->selectCases) { + auto type = getType(sc->keyset); + if (type == nullptr) return expression; + auto newsc = matchCase(expression, tuple, sc, type); + vec.push_back(newsc); + if (newsc != sc) changes = true; + } + if (changes) + expression = + new IR::SelectExpression(expression->srcInfo, expression->select, std::move(vec)); + setType(expression, IR::Type_State::get()); + setType(getOriginal(), IR::Type_State::get()); + return expression; +} + +const IR::Node *TypeInference::postorder(IR::AttribLocal *local) { + setType(local, local->type); + setType(getOriginal(), local->type); + return local; +} + +} // namespace P4 diff --git a/frontends/p4/typeChecking/typeCheckStmt.cpp b/frontends/p4/typeChecking/typeCheckStmt.cpp new file mode 100644 index 00000000000..9d22f41d62f --- /dev/null +++ b/frontends/p4/typeChecking/typeCheckStmt.cpp @@ -0,0 +1,339 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "frontends/p4/methodInstance.h" +#include "syntacticEquivalence.h" +#include "typeChecker.h" + +namespace P4 { + +const IR::Node *TypeInference::postorder(IR::IfStatement *conditional) { + LOG3("TI Visiting " << dbp(getOriginal())); + auto type = getType(conditional->condition); + if (type == nullptr) return conditional; + if (!type->is()) + typeError("Condition of %1% does not evaluate to a bool but %2%", conditional, + type->toString()); + return conditional; +} + +const IR::Node *TypeInference::postorder(IR::SwitchStatement *stat) { + LOG3("TI Visiting " << dbp(getOriginal())); + auto type = getType(stat->expression); + if (type == nullptr) return stat; + + if (auto ae = type->to()) { + // switch (table.apply(...)) + absl::flat_hash_map foundLabels; + const IR::Node *foundDefault = nullptr; + for (auto c : stat->cases) { + if (c->label->is()) { + if (foundDefault) + typeError("%1%: multiple 'default' labels %2%", c->label, foundDefault); + foundDefault = c->label; + continue; + } else if (auto pe = c->label->to()) { + cstring label = pe->path->name.name; + auto [it, inserted] = foundLabels.emplace(label, c->label); + if (!inserted) + typeError("%1%: 'switch' label duplicates %2%", c->label, it->second); + if (!ae->contains(label)) + typeError("%1% is not a legal label (action name)", c->label); + } else { + typeError("%1%: 'switch' label must be an action name or 'default'", c->label); + } + } + } else { + // switch (expression) + Comparison comp; + comp.left = stat->expression; + if (isCompileTimeConstant(stat->expression)) + warn(ErrorType::WARN_MISMATCH, "%1%: constant expression in switch", stat->expression); + + for (auto &c : stat->cases) { + if (!isCompileTimeConstant(c->label)) + typeError("%1%: must be a compile-time constant", c->label); + auto lt = getType(c->label); + if (lt == nullptr) continue; + if (lt->is() && type->is()) { + c = new IR::SwitchCase(c->srcInfo, new IR::Cast(c->label->srcInfo, type, c->label), + c->statement); + setType(c->label, type); + setCompileTimeConstant(c->label); + continue; + } else if (c->label->is()) { + continue; + } + comp.right = c->label; + bool b = compare(stat, type, lt, &comp); + if (b && comp.right != c->label) { + c = new IR::SwitchCase(c->srcInfo, comp.right, c->statement); + setCompileTimeConstant(c->label); + } + } + } + return stat; +} + +const IR::Node *TypeInference::postorder(IR::ReturnStatement *statement) { + LOG3("TI Visiting " << dbp(getOriginal())); + auto func = findOrigCtxt(); + if (func == nullptr) { + if (statement->expression != nullptr) + typeError("%1%: return with expression can only be used in a function", statement); + return statement; + } + + auto ftype = getType(func); + if (ftype == nullptr) return statement; + + BUG_CHECK(ftype->is(), "%1%: expected a method type for function", ftype); + auto mt = ftype->to(); + auto returnType = mt->returnType; + CHECK_NULL(returnType); + if (returnType->is()) { + if (statement->expression != nullptr) + typeError("%1%: return expression in function with void return", statement); + return statement; + } + + if (statement->expression == nullptr) { + typeError("%1%: return with no expression in a function returning %2%", statement, + returnType->toString()); + return statement; + } + + auto init = assignment(statement, returnType, statement->expression); + if (init != statement->expression) statement->expression = init; + return statement; +} + +const IR::Node *TypeInference::postorder(IR::AssignmentStatement *assign) { + LOG3("TI Visiting " << dbp(getOriginal())); + auto ltype = getType(assign->left); + if (ltype == nullptr) return assign; + + if (!isLeftValue(assign->left)) { + typeError("Expression %1% cannot be the target of an assignment", assign->left); + LOG2(assign->left); + return assign; + } + + auto newInit = assignment(assign, ltype, assign->right); + if (newInit != assign->right) + assign = new IR::AssignmentStatement(assign->srcInfo, assign->left, newInit); + return assign; +} + +const IR::Node *TypeInference::postorder(IR::ForInStatement *forin) { + LOG3("TI Visiting " << dbp(getOriginal())); + auto ltype = getType(forin->ref); + if (ltype == nullptr) return forin; + auto ctype = getType(forin->collection); + if (ctype == nullptr) return forin; + + if (!isLeftValue(forin->ref)) { + typeError("Expression %1% cannot be the target of an assignment", forin->ref); + LOG2(forin->ref); + return forin; + } + if (auto range = forin->collection->to()) { + auto rclone = range->clone(); + rclone->left = assignment(forin, ltype, rclone->left); + rclone->right = assignment(forin, ltype, rclone->right); + if (*range != *rclone) + forin->collection = rclone; + else + delete rclone; + } else if (auto *stack = ctype->to()) { + if (!canCastBetween(stack->elementType, ltype)) + typeError("%1% does not match header stack type %2%", forin->ref, ctype); + } else if (auto *list = ctype->to()) { + if (!canCastBetween(list->elementType, ltype)) + typeError("%1% does not match %2% element type", forin->ref, ctype); + } else { + error(ErrorType::ERR_UNSUPPORTED, + "%1%Typechecking does not support iteration over this collection of type %2%", + forin->collection->srcInfo, ctype); + } + return forin; +} + +const IR::Node *TypeInference::postorder(IR::ActionListElement *elem) { + if (done()) return elem; + auto type = getType(elem->expression); + if (type == nullptr) return elem; + + setType(elem, type); + setType(getOriginal(), type); + return elem; +} + +const IR::Node *TypeInference::postorder(IR::SelectCase *sc) { + auto type = getType(sc->state); + if (type != nullptr && type != IR::Type_State::get()) typeError("%1% must be state", sc); + return sc; +} + +const IR::Node *TypeInference::postorder(IR::KeyElement *elem) { + auto ktype = getType(elem->expression); + if (ktype == nullptr) return elem; + while (ktype->is()) ktype = getTypeType(ktype->to()->type); + if (!ktype->is() && !ktype->is() && + !ktype->is() && !ktype->is() && + !ktype->is()) + typeError("Key %1% field type must be a scalar type; it cannot be %2%", elem->expression, + ktype->toString()); + auto type = getType(elem->matchType); + if (type != nullptr && type != IR::Type_MatchKind::get()) + typeError("%1% must be a %2% value", elem->matchType, + IR::Type_MatchKind::get()->toString()); + if (isCompileTimeConstant(elem->expression) && !readOnly) + warn(ErrorType::WARN_IGNORE_PROPERTY, "%1%: constant key element", elem); + return elem; +} + +const IR::Node *TypeInference::postorder(IR::ActionList *al) { + LOG3("TI Visited " << dbp(al)); + BUG_CHECK(currentActionList == nullptr, "%1%: nested action list?", al); + currentActionList = al; + return al; +} + +const IR::ActionListElement *TypeInference::validateActionInitializer( + const IR::Expression *actionCall) { + // We cannot retrieve the action list from the table, because the + // table has not been modified yet. We want the latest version of + // the action list, as it has been already typechecked. + auto al = currentActionList; + if (al == nullptr) { + auto table = findContext(); + BUG_CHECK(table, "%1%: not within a table", actionCall); + typeError("%1% has no action list, so it cannot invoke '%2%'", table, actionCall); + return nullptr; + } + + auto call = actionCall->to(); + if (call == nullptr) { + typeError("%1%: expected an action call", actionCall); + return nullptr; + } + auto method = call->method; + if (!method->is()) BUG("%1%: unexpected expression", method); + auto pe = method->to(); + auto decl = getDeclaration(pe->path, !errorOnNullDecls); + if (errorOnNullDecls && decl == nullptr) { + typeError("%1%: Cannot resolve declaration", pe); + return nullptr; + } + + auto ale = al->actionList.getDeclaration(decl->getName()); + if (ale == nullptr) { + typeError("%1% not present in action list", call); + return nullptr; + } + + BUG_CHECK(ale->is(), "%1%: expected an ActionListElement", ale); + auto elem = ale->to(); + auto entrypath = elem->getPath(); + auto entrydecl = getDeclaration(entrypath, true); + if (entrydecl != decl) { + typeError("%1% and %2% refer to different actions", actionCall, elem); + return nullptr; + } + + // Check that the data-plane parameters + // match the data-plane parameters for the same action in + // the actions list. + auto actionListCall = elem->expression->to(); + CHECK_NULL(actionListCall); + auto type = typeMap->getType(actionListCall->method); + if (type == nullptr) { + typeError("%1%: action invocation should be after the `actions` list", actionCall); + return nullptr; + } + + if (actionListCall->arguments->size() > call->arguments->size()) { + typeError("%1%: not enough arguments", call); + return nullptr; + } + + SameExpression se(this, typeMap); + auto callInstance = MethodInstance::resolve(call, this, typeMap, getChildContext(), true); + auto listInstance = + MethodInstance::resolve(actionListCall, this, typeMap, getChildContext(), true); + + for (auto param : *listInstance->substitution.getParametersInArgumentOrder()) { + auto aa = listInstance->substitution.lookup(param); + auto da = callInstance->substitution.lookup(param); + if (da == nullptr) { + typeError("%1%: parameter should be assigned in call %2%", param, call); + return nullptr; + } + bool same = se.sameExpression(aa->expression, da->expression); + if (!same) { + typeError("%1%: argument does not match declaration in actions list: %2%", da, aa); + return nullptr; + } + } + + for (auto param : *callInstance->substitution.getParametersInOrder()) { + auto da = callInstance->substitution.lookup(param); + if (da == nullptr) { + typeError("%1%: parameter should be assigned in call %2%", param, call); + return nullptr; + } + } + + return elem; +} + +const IR::Node *TypeInference::postorder(IR::Property *prop) { + // Handle the default_action + if (prop->name == IR::TableProperties::defaultActionPropertyName) { + auto pv = prop->value->to(); + if (pv == nullptr) { + typeError("%1% table property should be an action", prop); + } else { + auto type = getType(pv->expression); + if (type == nullptr) return prop; + if (!type->is()) { + typeError("%1% table property should be an action", prop); + return prop; + } + auto at = type->to(); + if (at->parameters->size() != 0) { + typeError("%1%: parameter %2% does not have a corresponding argument", prop->value, + at->parameters->parameters.at(0)); + return prop; + } + + // Check that the default action appears in the list of actions. + BUG_CHECK(prop->value->is(), "%1% not an expression", prop); + auto def = prop->value->to()->expression; + auto ale = validateActionInitializer(def); + if (ale != nullptr) { + auto anno = ale->getAnnotation(IR::Annotation::tableOnlyAnnotation); + if (anno != nullptr) { + typeError("%1%: Action marked with %2% used as default action", prop, + IR::Annotation::tableOnlyAnnotation); + return prop; + } + } + } + } + return prop; +} + +} // namespace P4 diff --git a/frontends/p4/typeChecking/typeCheckTypes.cpp b/frontends/p4/typeChecking/typeCheckTypes.cpp new file mode 100644 index 00000000000..56f3e48cd49 --- /dev/null +++ b/frontends/p4/typeChecking/typeCheckTypes.cpp @@ -0,0 +1,559 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "constantTypeSubstitution.h" +#include "typeChecker.h" + +namespace P4 { + +bool hasVarbitsOrUnions(const TypeMap *typeMap, const IR::Type *type) { + // called for a canonical type + if (type->is() || type->is()) { + return true; + } else if (auto ht = type->to()) { + const IR::StructField *varbit = nullptr; + for (auto f : ht->fields) { + auto ftype = typeMap->getType(f); + if (ftype == nullptr) continue; + if (ftype->is()) { + if (varbit == nullptr) { + varbit = f; + } else { + typeError("%1% and %2%: multiple varbit fields in a header", varbit, f); + return type; + } + } + } + return varbit != nullptr; + } else if (auto at = type->to()) { + return hasVarbitsOrUnions(typeMap, at->elementType); + } else if (auto tpl = type->to()) { + for (auto f : tpl->components) { + if (hasVarbitsOrUnions(typeMap, f)) return true; + } + } + return false; +} + +bool TypeInference::onlyBitsOrBitStructs(const IR::Type *type) const { + // called for a canonical type + if (type->is() || type->is() || type->is()) { + return true; + } else if (auto ht = type->to()) { + for (auto f : ht->fields) { + auto ftype = typeMap->getType(f); + BUG_CHECK((ftype != nullptr), + "onlyBitsOrBitStructs check could not find type " + "for %1%", + f); + if (!onlyBitsOrBitStructs(ftype)) return false; + } + return true; + } + return false; +} + +const IR::Type *TypeInference::setTypeType(const IR::Type *type, bool learn) { + if (done()) return type; + const IR::Type *typeToCanonicalize; + if (readOnly) + typeToCanonicalize = getOriginal(); + else + typeToCanonicalize = type; + auto canon = canonicalize(typeToCanonicalize); + if (canon != nullptr) { + // Learn the new type + if (canon != typeToCanonicalize && learn) { + bool errs = this->learn(canon, this, getChildContext()); + if (errs) return nullptr; + } + auto tt = new IR::Type_Type(canon); + setType(getOriginal(), tt); + setType(type, tt); + } + return canon; +} + +const IR::Node *TypeInference::postorder(IR::Type_Error *decl) { + (void)setTypeType(decl); + for (auto id : *decl->getDeclarations()) setType(id->getNode(), decl); + return decl; +} + +const IR::Node *TypeInference::postorder(IR::Type_Table *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Type *type) { + BUG("Should never be found in IR: %1%", type); +} + +const IR::Node *TypeInference::postorder(IR::P4Control *cont) { + (void)setTypeType(cont, false); + return cont; +} + +const IR::Node *TypeInference::postorder(IR::P4Parser *parser) { + (void)setTypeType(parser, false); + return parser; +} + +const IR::Node *TypeInference::postorder(IR::Type_InfInt *type) { + if (done()) return type; + auto tt = new IR::Type_Type(getOriginal()); + setType(getOriginal(), tt); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_ArchBlock *decl) { + (void)setTypeType(decl); + return decl; +} + +const IR::Node *TypeInference::postorder(IR::Type_Package *decl) { + auto canon = setTypeType(decl); + if (canon != nullptr) { + for (auto p : decl->getConstructorParameters()->parameters) { + auto ptype = getType(p); + if (ptype == nullptr) + // error + return decl; + if (ptype->is() || ptype->is()) + typeError("%1%: Invalid package parameter type", p); + } + } + return decl; +} + +class ContainsType : public Inspector { + const IR::Type *contained; + const TypeMap *typeMap; + const IR::Type *found = nullptr; + + ContainsType(const IR::Type *contained, const TypeMap *typeMap) + : contained(contained), typeMap(typeMap) { + CHECK_NULL(contained); + CHECK_NULL(typeMap); + } + + bool preorder(const IR::Type *type) override { + LOG3("ContainsType " << type); + if (typeMap->equivalent(type, contained)) found = type; + return true; + } + + public: + static const IR::Type *find(const IR::Type *type, const IR::Type *contained, + const TypeMap *typeMap) { + ContainsType c(contained, typeMap); + LOG3("Checking if " << type << " contains " << contained); + type->apply(c); + return c.found; + } +}; + +const IR::Node *TypeInference::postorder(IR::Type_Specialized *type) { + // Check for recursive type specializations, e.g., + // extern e {}; e> x; + auto baseType = getTypeType(type->baseType); + if (!baseType) return type; + for (auto arg : *type->arguments) { + auto argtype = getTypeType(arg); + if (!argtype) return type; + if (auto self = ContainsType::find(argtype, baseType, typeMap)) { + typeError("%1%: contains self '%2%' as type argument", type->baseType, self); + return type; + } + if (auto tg = argtype->to()) { + if (tg->getTypeParameters()->size() != 0) { + typeError("%1%: generic type needs type arguments", arg); + return type; + } + } + } + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_SpecializedCanonical *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Name *typeName) { + if (done()) return typeName; + const IR::Type *type; + + if (typeName->path->isDontCare()) { + auto t = IR::Type_Dontcare::get(); + type = new IR::Type_Type(t); + } else { + auto decl = getDeclaration(typeName->path, !errorOnNullDecls); + if (errorOnNullDecls && decl == nullptr) { + typeError("%1%: Cannot resolve type", typeName); + return typeName; + } + // Check for references of a control or parser within itself. + auto ctrl = findContext(); + if (ctrl != nullptr && ctrl->name == decl->getName()) { + typeError("%1%: Cannot refer to control inside itself", typeName); + return typeName; + } + auto parser = findContext(); + if (parser != nullptr && parser->name == decl->getName()) { + typeError("%1%: Cannot refer parser inside itself", typeName); + return typeName; + } + + type = getType(decl->getNode()); + if (type == nullptr) return typeName; + BUG_CHECK(type->is(), "%1%: should be a Type_Type", type); + } + setType(typeName->path, type->to()->type); + setType(getOriginal(), type); + setType(typeName, type); + return typeName; +} + +const IR::Node *TypeInference::postorder(IR::Type_ActionEnum *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Enum *type) { + auto canon = setTypeType(type); + for (auto e : *type->getDeclarations()) setType(e->getNode(), canon); + return type; +} + +const IR::Node *TypeInference::preorder(IR::Type_SerEnum *type) { + auto canon = setTypeType(type); + for (auto e : *type->getDeclarations()) setType(e->getNode(), canon); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Var *typeVar) { + if (done()) return typeVar; + const IR::Type *type; + if (typeVar->name.isDontCare()) + type = IR::Type_Dontcare::get(); + else + type = getOriginal(); + auto tt = new IR::Type_Type(type); + setType(getOriginal(), tt); + setType(typeVar, tt); + return typeVar; +} + +const IR::Node *TypeInference::postorder(IR::Type_List *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Tuple *type) { + for (auto field : type->components) { + auto fieldType = getTypeType(field); + if (auto spec = fieldType->to()) fieldType = spec->baseType; + if (fieldType->is() || fieldType->is() || + fieldType->is()) { + typeError("%1%: not supported as a tuple field", field); + return type; + } + } + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_P4List *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Set *type) { + (void)setTypeType(type); + return type; +} + +/// get size int bits required to represent given constant +static int getConstantsRepresentationSize(big_int val, bool isSigned) { + if (val < 0) { + val = -val; + } + int cnt = 0; + while (val > 0) { + ++cnt; + val >>= 1; + } + return cnt + int(isSigned); +} + +const IR::Type_Bits *TypeInference::checkUnderlyingEnumType(const IR::Type *enumType) { + const auto *resolvedType = getTypeType(enumType); + CHECK_NULL(resolvedType); + if (const auto *type = resolvedType->to()) { + return type; + } + std::string note; + if (resolvedType->is()) { + note = "; note that the used type is unsized integral type"; + } else if (resolvedType->is()) { + note = "; note that type-declared types are not allowed even if they are fixed-size"; + } + typeError("%1%: Illegal type for enum; only bit<> and int<> are allowed%2%", enumType, note); + return nullptr; +} + +/// Check if the value initializer fits into the underlying enum type. Emits error and returns false +/// if it does not fit. Returns true if it fits. +static bool checkEnumValueInitializer(const IR::Type_Bits *type, const IR::Expression *initializer, + const IR::Type_SerEnum *serEnum, + const IR::SerEnumMember *member) { + // validate the constant fits -- non-fitting enum constants should produce error + if (const auto *constant = initializer->to()) { + // signed values are two's complement, so [-2^(n-1)..2^(n-1)-1] + big_int low = type->isSigned ? -(big_int(1) << type->size - 1) : big_int(0); + big_int high = (big_int(1) << (type->isSigned ? type->size - 1 : type->size)) - 1; + + if (constant->value < low || constant->value > high) { + int required = getConstantsRepresentationSize(constant->value, type->isSigned); + std::string extraMsg; + if (!type->isSigned && constant->value < low) { + extraMsg = + str(boost::format( + "the value %1% is negative, but the underlying type %2% is unsigned") % + constant->value % type->toString()); + } else { + extraMsg = + str(boost::format("the value %1% requires %2% bits but the underlying " + "%3% type %4% only contains %5% bits") % + constant->value % required % (type->isSigned ? "signed" : "unsigned") % + type->toString() % type->size); + } + ::error(ErrorType::ERR_TYPE_ERROR, + "%1%: Serialized enum constant value %2% is out of bounds of the underlying " + "type %3%; %4%", + member, constant->value, serEnum->type, extraMsg); + return false; + } + } + return true; +} + +const IR::Node *TypeInference::postorder(IR::SerEnumMember *member) { + /* + The type of the member is initially set in the Type_SerEnum preorder visitor. + Here we check additional constraints and we may correct the member. + if (done()) return member; + */ + const auto *serEnum = findContext(); + CHECK_NULL(serEnum); + const auto *type = checkUnderlyingEnumType(serEnum->type); + if (!type || !checkEnumValueInitializer(type, member->value, serEnum, member)) { + return member; + } + const auto *exprType = getType(member->value); + auto *tvs = unifyCast(member, type, exprType, + "Enum member '%1%' has type '%2%' and not the expected type '%3%'", + {member, exprType, type}); + if (tvs == nullptr) + // error already signalled + return member; + if (tvs->isIdentity()) return member; + + ConstantTypeSubstitution cts(tvs, typeMap, this); + member->value = cts.convert(member->value, getChildContext()); // sets type + if (!typeMap->getType(member)) setType(member, getTypeType(serEnum)); + return member; +} + +const IR::Node *TypeInference::postorder(IR::P4ValueSet *decl) { + if (done()) return decl; + // This is a specialized version of setTypeType + auto canon = canonicalize(decl->elementType); + if (canon != nullptr) { + if (canon != decl->elementType) { + bool errs = learn(canon, this, getChildContext()); + if (errs) return nullptr; + } + if (!canon->is() && !canon->is() && + !canon->is() && !canon->is() && + !canon->is() && !canon->is() && + !canon->is()) + typeError("%1%: Illegal type for value_set element type", decl->elementType); + + auto tt = new IR::Type_Set(canon); + setType(getOriginal(), tt); + setType(decl, tt); + } + return decl; +} + +const IR::Node *TypeInference::postorder(IR::Type_Extern *type) { + if (done()) return type; + setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Method *type) { + auto methodType = type; + if (auto ext = findContext()) { + auto extName = ext->name.name; + if (auto method = findContext()) { + auto name = method->name.name; + if (methodType->returnType && (methodType->returnType->is() || + methodType->returnType->is())) + typeError("%1%: illegal return type for method", method->type->returnType); + if (name == extName) { + // This is a constructor. + if (this->called_by == nullptr && // canonical types violate this rule + method->type->typeParameters != nullptr && + method->type->typeParameters->size() > 0) { + typeError("%1%: Constructors cannot have type parameters", + method->type->typeParameters); + return type; + } + // For constructors we add the type variables of the + // enclosing extern as type parameters. Given + // extern e { e(); } + // the type of method e is in fact e(); + methodType = new IR::Type_Method(type->srcInfo, ext->typeParameters, + type->returnType, type->parameters, name); + } + } + } + (void)setTypeType(methodType); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Action *type) { + (void)setTypeType(type); + BUG_CHECK(type->typeParameters->size() == 0, "%1%: Generic action?", type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Base *type) { + (void)setTypeType(type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Newtype *type) { + (void)setTypeType(type); + auto argType = getTypeType(type->type); + if (!argType->is() && !argType->is() && + !argType->is()) + typeError("%1%: `type' can only be applied to base types", type); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Typedef *tdecl) { + if (done()) return tdecl; + auto type = getType(tdecl->type); + if (type == nullptr) return tdecl; + BUG_CHECK(type->is(), "%1%: expected a TypeType", type); + auto stype = type->to()->type; + if (auto gen = stype->to()) { + if (gen->getTypeParameters()->size() != 0) { + typeError("%1%: no type parameters supplied for generic type", tdecl->type); + return tdecl; + } + } + setType(getOriginal(), type); + setType(tdecl, type); + return tdecl; +} + +const IR::Node *TypeInference::postorder(IR::Type_Stack *type) { + auto canon = setTypeType(type); + if (canon == nullptr) return type; + + auto etype = canon->to()->elementType; + if (etype == nullptr) return type; + + if (!etype->is() && !etype->is() && + !etype->is()) + typeError("Header stack %1% used with non-header type %2%", type, etype->toString()); + return type; +} + +/// Validate the fields of a struct type using the supplied checker. +/// The checker returns "false" when a field is invalid. +/// Return true on success +bool TypeInference::validateFields(const IR::Type *type, + std::function checker) const { + if (type == nullptr) return false; + BUG_CHECK(type->is(), "%1%; expected a Struct-like", type); + auto strct = type->to(); + bool err = false; + for (auto field : strct->fields) { + auto ftype = getType(field); + if (ftype == nullptr) return false; + if (!checker(ftype)) { + typeError("Field '%1%' of '%2%' cannot have type '%3%'", field, type->toString(), + field->type); + err = true; + } + } + return !err; +} + +const IR::Node *TypeInference::postorder(IR::StructField *field) { + if (done()) return field; + auto canon = getTypeType(field->type); + if (canon == nullptr) return field; + + setType(getOriginal(), canon); + setType(field, canon); + return field; +} + +const IR::Node *TypeInference::postorder(IR::Type_Header *type) { + auto canon = setTypeType(type); + auto validator = [this](const IR::Type *t) { + while (t->is()) t = getTypeType(t->to()->type); + return t->is() || t->is() || + (t->is() && onlyBitsOrBitStructs(t)) || t->is() || + t->is() || t->is() || + t->is(); + }; + validateFields(canon, validator); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_Struct *type) { + auto canon = setTypeType(type); + auto validator = [this](const IR::Type *t) { + while (auto *nt = t->to()) t = getTypeType(nt->type); + return t->is() || t->is() || t->is() || + t->is() || t->is() || t->is() || + t->is() || t->is() || t->is() || + t->is() || t->is() || + t->is() || t->is() || + t->is() || t->is(); + }; + (void)validateFields(canon, validator); + return type; +} + +const IR::Node *TypeInference::postorder(IR::Type_HeaderUnion *type) { + auto canon = setTypeType(type); + auto validator = [](const IR::Type *t) { + return t->is() || t->is() || + t->is(); + }; + (void)validateFields(canon, validator); + return type; +} + +} // namespace P4 diff --git a/frontends/p4/typeChecking/typeChecker.cpp b/frontends/p4/typeChecking/typeChecker.cpp index 811b7352e31..bd4f55944d3 100644 --- a/frontends/p4/typeChecking/typeChecker.cpp +++ b/frontends/p4/typeChecking/typeChecker.cpp @@ -16,83 +16,18 @@ limitations under the License. #include "typeChecker.h" -#include - -#include "absl/container/flat_hash_map.h" +#include "constantTypeSubstitution.h" #include "frontends/common/constantFolding.h" #include "frontends/common/resolveReferences/referenceMap.h" #include "frontends/common/resolveReferences/resolveReferences.h" -#include "frontends/p4/coreLibrary.h" -#include "frontends/p4/enumInstance.h" -#include "frontends/p4/methodInstance.h" -#include "frontends/p4/toP4/toP4.h" -#include "lib/algorithm.h" #include "lib/cstring.h" -#include "lib/hash.h" #include "lib/log.h" -#include "syntacticEquivalence.h" #include "typeConstraints.h" #include "typeSubstitution.h" #include "typeUnification.h" namespace P4 { -namespace { -// Used to set the type of Constants after type inference -class ConstantTypeSubstitution : public Transform, ResolutionContext { - TypeVariableSubstitution *subst; - TypeMap *typeMap; - TypeInference *tc; - - public: - ConstantTypeSubstitution(TypeVariableSubstitution *subst, TypeMap *typeMap, TypeInference *tc) - : subst(subst), typeMap(typeMap), tc(tc) { - CHECK_NULL(subst); - CHECK_NULL(typeMap); - CHECK_NULL(tc); - LOG3("ConstantTypeSubstitution " << subst); - } - - const IR::Node *postorder(IR::Constant *cst) override { - auto cstType = typeMap->getType(getOriginal(), true); - if (!cstType->is()) return cst; - auto repl = cstType; - while (repl->is()) { - auto next = subst->get(repl->to()); - BUG_CHECK(next != repl, "Cycle in substitutions: %1%", next); - if (!next) break; - repl = next; - } - if (repl != cstType) { - // We may replace a type variable with another one - LOG2("Inferred type " << repl << " for " << cst); - cst = new IR::Constant(cst->srcInfo, repl, cst->value, cst->base); - } else { - LOG2("No type inferred for " << cst << " repl is " << repl); - } - return cst; - } - - const IR::Expression *convert(const IR::Expression *expr, const Visitor::Context *ctxt) { - auto result = expr->apply(*this, ctxt)->to(); - if (result != expr && (::errorCount() == 0)) tc->learn(result, this, ctxt); - return result; - } - const IR::Vector *convert(const IR::Vector *vec, - const Visitor::Context *ctxt) { - auto result = vec->apply(*this, ctxt)->to>(); - if (result != vec) tc->learn(result, this, ctxt); - return result; - } - const IR::Vector *convert(const IR::Vector *vec, - const Visitor::Context *ctxt) { - auto result = vec->apply(*this, ctxt)->to>(); - if (result != vec) tc->learn(result, this, ctxt); - return result; - } -}; -} // namespace - TypeChecking::TypeChecking(ReferenceMap *refMap, TypeMap *typeMap, bool updateExpressions) { addPasses({new P4::TypeInference(typeMap, /* readOnly */ true, /* checkArrays */ true, /* errorOnNullDecls */ true), @@ -253,9 +188,9 @@ TypeVariableSubstitution *TypeInference::unifyBase( return tvs; } -const IR::Type *TypeInference::canonicalizeFields( - const IR::Type_StructLike *type, - std::function *)> constructor) { +template +const IR::Type *TypeInference::canonicalizeFields(const IR::Type_StructLike *type, + Ctor constructor) { bool changes = false; auto fields = new IR::IndexedVector(); for (auto field : type->fields) { @@ -272,54 +207,6 @@ const IR::Type *TypeInference::canonicalizeFields( return type; } -const IR::ParameterList *TypeInference::canonicalizeParameters(const IR::ParameterList *params) { - if (params == nullptr) return params; - - bool changes = false; - auto vec = new IR::IndexedVector(); - for (auto p : *params->getEnumerator()) { - auto paramType = getTypeType(p->type); - if (paramType == nullptr) return nullptr; - BUG_CHECK(!paramType->is(), "%1%: Unexpected parameter type", paramType); - if (paramType != p->type) { - p = new IR::Parameter(p->srcInfo, p->name, p->annotations, p->direction, paramType, - p->defaultValue); - changes = true; - } - setType(p, paramType); - vec->push_back(p); - } - if (changes) - return new IR::ParameterList(params->srcInfo, *vec); - else - return params; -} - -bool TypeInference::checkParameters(const IR::ParameterList *paramList, bool forbidModules, - bool forbidPackage) const { - for (auto p : paramList->parameters) { - auto type = getType(p); - if (type == nullptr) return false; - if (auto ts = type->to()) type = ts->baseType; - if (forbidPackage && type->is()) { - typeError("%1%: parameter cannot be a package", p); - return false; - } - if (p->direction != IR::Direction::None && - (type->is() || type->is())) { - typeError("%1%: a parameter with type %2% cannot have a direction", p, type); - return false; - } - if ((forbidModules || p->direction != IR::Direction::None) && - (type->is() || type->is() || - type->is() || type->is())) { - typeError("%1%: parameter cannot have type %2%", p, type); - return false; - } - } - return true; -} - /** * Bind the parameters with the specified arguments. * For example, given a type @@ -605,97 +492,12 @@ const IR::Node *TypeInference::preorder(IR::P4Program *program) { return program; } -const IR::Node *TypeInference::postorder(IR::Type_Error *decl) { - (void)setTypeType(decl); - for (auto id : *decl->getDeclarations()) setType(id->getNode(), decl); - return decl; -} - const IR::Node *TypeInference::postorder(IR::Declaration_MatchKind *decl) { if (done()) return decl; for (auto id : *decl->getDeclarations()) setType(id->getNode(), IR::Type_MatchKind::get()); return decl; } -const IR::Node *TypeInference::postorder(IR::Type_Table *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::P4Table *table) { - currentActionList = nullptr; - if (done()) return table; - auto type = new IR::Type_Table(table); - setType(getOriginal(), type); - setType(table, type); - return table; -} - -const IR::Node *TypeInference::postorder(IR::P4Action *action) { - if (done()) return action; - auto pl = canonicalizeParameters(action->parameters); - if (pl == nullptr) return action; - if (!checkParameters(action->parameters, forbidModules, forbidPackages)) return action; - auto type = new IR::Type_Action(new IR::TypeParameters(), nullptr, pl); - - bool foundDirectionless = false; - for (auto p : action->parameters->parameters) { - auto ptype = getType(p); - BUG_CHECK(ptype, "%1%: parameter type missing when it was found previously", p); - if (ptype->is()) - typeError("%1%: Action parameters cannot have extern types", p->type); - if (p->direction == IR::Direction::None) - foundDirectionless = true; - else if (foundDirectionless) - typeError("%1%: direction-less action parameters have to be at the end", p); - } - setType(getOriginal(), type); - setType(action, type); - return action; -} - -const IR::Node *TypeInference::postorder(IR::Declaration_Variable *decl) { - if (done()) return decl; - auto type = getTypeType(decl->type); - if (type == nullptr) return decl; - - if (const IR::IMayBeGenericType *gt = type->to()) { - // Check that there are no unbound type parameters - if (!gt->getTypeParameters()->empty()) { - typeError("Unspecified type parameters for %1% in %2%", gt, decl); - return decl; - } - } - - const IR::Type *baseType = type; - if (auto sc = type->to()) baseType = sc->baseType; - if (baseType->is() || baseType->is() || - baseType->is() || baseType->is()) { - typeError("%1%: cannot declare variables of type '%2%' (consider using an instantiation)", - decl, type); - return decl; - } - - if (type->is() || type->is()) { - typeError("%1%: Cannot declare variables with type %2%", decl, type); - return decl; - } - - auto orig = getOriginal(); - if (decl->initializer != nullptr) { - auto init = assignment(decl, type, decl->initializer); - if (decl->initializer != init) { - auto declType = type->getP4Type(); - decl->type = declType; - decl->initializer = init; - LOG2("Created new declaration " << decl); - } - } - setType(decl, type); - setType(orig, type); - return decl; -} - bool TypeInference::canCastBetween(const IR::Type *dest, const IR::Type *src) const { if (src->is()) return false; if (typeMap->equivalent(src, dest)) return true; @@ -918,28 +720,6 @@ const IR::Node *TypeInference::postorder(IR::Annotation *annotation) { return annotation; } -const IR::Node *TypeInference::postorder(IR::Declaration_Constant *decl) { - if (done()) return decl; - auto type = getTypeType(decl->type); - if (type == nullptr) return decl; - - if (type->is()) { - typeError("%1%: Cannot declare constants of extern types", decl->name); - return decl; - } - - if (!isCompileTimeConstant(decl->initializer)) - typeError("%1%: Cannot evaluate initializer to a compile-time constant", decl->initializer); - auto orig = getOriginal(); - auto newInit = assignment(decl, type, decl->initializer); - if (newInit != decl->initializer) - decl = new IR::Declaration_Constant(decl->srcInfo, decl->name, decl->annotations, - decl->type, newInit); - setType(decl, type); - setType(orig, type); - return decl; -} - // Returns the type of the constructed object and // new arguments for constructor, which may have inserted casts. std::pair *> TypeInference::checkExternConstructor( @@ -1060,126 +840,6 @@ std::pair *> TypeInference::che return {objectType, arguments}; } -// Return true on success -bool TypeInference::checkAbstractMethods(const IR::Declaration_Instance *inst, - const IR::Type_Extern *type) { - // Make a list of the abstract methods - IR::NameMap virt; - for (auto m : type->methods) - if (m->isAbstract) virt.addUnique(m->name, m); - if (virt.size() == 0 && inst->initializer == nullptr) return true; - if (virt.size() == 0 && inst->initializer != nullptr) { - typeError("%1%: instance initializers for extern without abstract methods", - inst->initializer); - return false; - } else if (virt.size() != 0 && inst->initializer == nullptr) { - typeError("%1%: must declare abstract methods for %2%", inst, type); - return false; - } - - for (auto d : inst->initializer->components) { - if (auto *func = d->to()) { - LOG2("Type checking " << func); - if (func->type->typeParameters->size() != 0) { - typeError("%1%: abstract method implementations cannot be generic", func); - return false; - } - auto ftype = getType(func); - if (virt.find(func->name.name) == virt.end()) { - typeError("%1%: no matching abstract method in %2%", func, type); - return false; - } - auto meth = virt[func->name.name]; - auto methtype = getType(meth); - virt.erase(func->name.name); - auto tvs = - unify(inst, methtype, ftype, "Method '%1%' does not have the expected type '%2%'", - {func, methtype}); - if (tvs == nullptr) return false; - BUG_CHECK(errorCount() > 0 || tvs->isIdentity(), "%1%: expected no type variables", - tvs); - } - } - bool rv = true; - for (auto &vm : virt) { - if (!vm.second->annotations->getSingle("optional"_cs)) { - typeError("%1%: %2% abstract method not implemented", inst, vm.second); - rv = false; - } - } - return rv; -} - -const IR::Node *TypeInference::preorder(IR::Declaration_Instance *decl) { - // We need to control the order of the type-checking: we want to do first - // the declaration, and then typecheck the initializer if present. - if (done()) return decl; - visit(decl->type, "type"); - visit(decl->arguments, "arguments"); - visit(decl->annotations, "annotations"); - visit(decl->properties, "properties"); - - auto type = getTypeType(decl->type); - if (type == nullptr) { - prune(); - return decl; - } - auto orig = getOriginal(); - - auto simpleType = type; - if (auto *sc = type->to()) simpleType = sc->substituted; - - if (auto et = simpleType->to()) { - auto [newType, newArgs] = checkExternConstructor(decl, et, decl->arguments); - if (newArgs == nullptr) { - prune(); - return decl; - } - // type can be Type_Extern or Type_SpecializedCanonical. If it is already - // specialized, the type arguments were specified explicitly. - // Otherwise, we use the type received from checkExternConstructor, which - // has substituted the type variables with fresh ones. - if (type->is()) type = newType; - decl->arguments = newArgs; - setType(orig, type); - setType(decl, type); - - if (decl->initializer != nullptr) visit(decl->initializer); - // This will need the decl type to be already known - bool s = checkAbstractMethods(decl, et); - if (!s) { - prune(); - return decl; - } - } else if (simpleType->is()) { - if (decl->initializer != nullptr) { - typeError("%1%: initializers only allowed for extern instances", decl->initializer); - prune(); - return decl; - } - if (!simpleType->is() && (findContext() == nullptr)) { - ::error(ErrorType::ERR_INVALID, "%1%: cannot instantiate at top-level", decl); - return decl; - } - auto typeAndArgs = - containerInstantiation(decl, decl->arguments, simpleType->to()); - auto type = typeAndArgs.first; - auto args = typeAndArgs.second; - if (type == nullptr || args == nullptr) { - prune(); - return decl; - } - learn(type, this, getChildContext()); - if (args != decl->arguments) decl->arguments = args; - setType(decl, type); - setType(orig, type); - } else { - typeError("%1%: cannot allocate objects of type %2%", decl, type); - } - prune(); - return decl; -} - /// @returns: A pair containing the type returned by the constructor and the new arguments /// (which may change due to insertion of casts). std::pair *> TypeInference::containerInstantiation( @@ -1264,18 +924,6 @@ std::pair *> TypeInference::con return std::pair *>(returnType, newArgs); } -const IR::Node *TypeInference::preorder(IR::Function *function) { - if (done()) return function; - visit(function->type); - auto type = getTypeType(function->type); - if (type == nullptr) return function; - setType(getOriginal(), type); - setType(function, type); - visit(function->body); - prune(); - return function; -} - const IR::Node *TypeInference::postorder(IR::Argument *arg) { if (done()) return arg; auto type = getType(arg->expression); @@ -1285,3104 +933,4 @@ const IR::Node *TypeInference::postorder(IR::Argument *arg) { return arg; } -const IR::Node *TypeInference::postorder(IR::Method *method) { - if (done()) return method; - auto type = getTypeType(method->type); - if (type == nullptr) return method; - if (auto mtype = type->to()) { - if (mtype->returnType) { - if (auto gen = mtype->returnType->to()) { - if (gen->getTypeParameters()->size() != 0) { - typeError("%1%: no type parameters supplied for return generic type", - method->type->returnType); - return method; - } - } - } - } - setType(getOriginal(), type); - setType(method, type); - return method; -} - -////////////////////////////////////////////// Types - -const IR::Type *TypeInference::setTypeType(const IR::Type *type, bool learn) { - if (done()) return type; - const IR::Type *typeToCanonicalize; - if (readOnly) - typeToCanonicalize = getOriginal(); - else - typeToCanonicalize = type; - auto canon = canonicalize(typeToCanonicalize); - if (canon != nullptr) { - // Learn the new type - if (canon != typeToCanonicalize && learn) { - bool errs = this->learn(canon, this, getChildContext()); - if (errs) return nullptr; - } - auto tt = new IR::Type_Type(canon); - setType(getOriginal(), tt); - setType(type, tt); - } - return canon; -} - -const IR::Node *TypeInference::postorder(IR::Type_Type *type) { - BUG("Should never be found in IR: %1%", type); -} - -const IR::Node *TypeInference::postorder(IR::P4Control *cont) { - (void)setTypeType(cont, false); - return cont; -} - -const IR::Node *TypeInference::postorder(IR::P4Parser *parser) { - (void)setTypeType(parser, false); - return parser; -} - -const IR::Node *TypeInference::postorder(IR::Type_InfInt *type) { - if (done()) return type; - auto tt = new IR::Type_Type(getOriginal()); - setType(getOriginal(), tt); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_ArchBlock *decl) { - (void)setTypeType(decl); - return decl; -} - -const IR::Node *TypeInference::postorder(IR::Type_Package *decl) { - auto canon = setTypeType(decl); - if (canon != nullptr) { - for (auto p : decl->getConstructorParameters()->parameters) { - auto ptype = getType(p); - if (ptype == nullptr) - // error - return decl; - if (ptype->is() || ptype->is()) - typeError("%1%: Invalid package parameter type", p); - } - } - return decl; -} - -class ContainsType : public Inspector { - const IR::Type *contained; - const TypeMap *typeMap; - const IR::Type *found = nullptr; - - ContainsType(const IR::Type *contained, const TypeMap *typeMap) - : contained(contained), typeMap(typeMap) { - CHECK_NULL(contained); - CHECK_NULL(typeMap); - } - - bool preorder(const IR::Type *type) override { - LOG3("ContainsType " << type); - if (typeMap->equivalent(type, contained)) found = type; - return true; - } - - public: - static const IR::Type *find(const IR::Type *type, const IR::Type *contained, - const TypeMap *typeMap) { - ContainsType c(contained, typeMap); - LOG3("Checking if " << type << " contains " << contained); - type->apply(c); - return c.found; - } -}; - -const IR::Node *TypeInference::postorder(IR::Type_Specialized *type) { - // Check for recursive type specializations, e.g., - // extern e {}; e> x; - auto baseType = getTypeType(type->baseType); - if (!baseType) return type; - for (auto arg : *type->arguments) { - auto argtype = getTypeType(arg); - if (!argtype) return type; - if (auto self = ContainsType::find(argtype, baseType, typeMap)) { - typeError("%1%: contains self '%2%' as type argument", type->baseType, self); - return type; - } - if (auto tg = argtype->to()) { - if (tg->getTypeParameters()->size() != 0) { - typeError("%1%: generic type needs type arguments", arg); - return type; - } - } - } - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_SpecializedCanonical *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Name *typeName) { - if (done()) return typeName; - const IR::Type *type; - - if (typeName->path->isDontCare()) { - auto t = IR::Type_Dontcare::get(); - type = new IR::Type_Type(t); - } else { - auto decl = getDeclaration(typeName->path, !errorOnNullDecls); - if (errorOnNullDecls && decl == nullptr) { - typeError("%1%: Cannot resolve type", typeName); - return typeName; - } - // Check for references of a control or parser within itself. - auto ctrl = findContext(); - if (ctrl != nullptr && ctrl->name == decl->getName()) { - typeError("%1%: Cannot refer to control inside itself", typeName); - return typeName; - } - auto parser = findContext(); - if (parser != nullptr && parser->name == decl->getName()) { - typeError("%1%: Cannot refer parser inside itself", typeName); - return typeName; - } - - type = getType(decl->getNode()); - if (type == nullptr) return typeName; - BUG_CHECK(type->is(), "%1%: should be a Type_Type", type); - } - setType(typeName->path, type->to()->type); - setType(getOriginal(), type); - setType(typeName, type); - return typeName; -} - -const IR::Node *TypeInference::postorder(IR::Type_ActionEnum *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Enum *type) { - auto canon = setTypeType(type); - for (auto e : *type->getDeclarations()) setType(e->getNode(), canon); - return type; -} - -const IR::Node *TypeInference::preorder(IR::Type_SerEnum *type) { - auto canon = setTypeType(type); - for (auto e : *type->getDeclarations()) setType(e->getNode(), canon); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Var *typeVar) { - if (done()) return typeVar; - const IR::Type *type; - if (typeVar->name.isDontCare()) - type = IR::Type_Dontcare::get(); - else - type = getOriginal(); - auto tt = new IR::Type_Type(type); - setType(getOriginal(), tt); - setType(typeVar, tt); - return typeVar; -} - -const IR::Node *TypeInference::postorder(IR::Type_List *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Tuple *type) { - for (auto field : type->components) { - auto fieldType = getTypeType(field); - if (auto spec = fieldType->to()) fieldType = spec->baseType; - if (fieldType->is() || fieldType->is() || - fieldType->is()) { - typeError("%1%: not supported as a tuple field", field); - return type; - } - } - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_P4List *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Set *type) { - (void)setTypeType(type); - return type; -} - -/// get size int bits required to represent given constant -static int getConstantsRepresentationSize(big_int val, bool isSigned) { - if (val < 0) { - val = -val; - } - int cnt = 0; - while (val > 0) { - ++cnt; - val >>= 1; - } - return cnt + int(isSigned); -} - -const IR::Type_Bits *TypeInference::checkUnderlyingEnumType(const IR::Type *enumType) { - const auto *resolvedType = getTypeType(enumType); - CHECK_NULL(resolvedType); - if (const auto *type = resolvedType->to()) { - return type; - } - std::string note; - if (resolvedType->is()) { - note = "; note that the used type is unsized integral type"; - } else if (resolvedType->is()) { - note = "; note that type-declared types are not allowed even if they are fixed-size"; - } - typeError("%1%: Illegal type for enum; only bit<> and int<> are allowed%2%", enumType, note); - return nullptr; -} - -/// Check if the value initializer fits into the underlying enum type. Emits error and returns false -/// if it does not fit. Returns true if it fits. -static bool checkEnumValueInitializer(const IR::Type_Bits *type, const IR::Expression *initializer, - const IR::Type_SerEnum *serEnum, - const IR::SerEnumMember *member) { - // validate the constant fits -- non-fitting enum constants should produce error - if (const auto *constant = initializer->to()) { - // signed values are two's complement, so [-2^(n-1)..2^(n-1)-1] - big_int low = type->isSigned ? -(big_int(1) << type->size - 1) : big_int(0); - big_int high = (big_int(1) << (type->isSigned ? type->size - 1 : type->size)) - 1; - - if (constant->value < low || constant->value > high) { - int required = getConstantsRepresentationSize(constant->value, type->isSigned); - std::string extraMsg; - if (!type->isSigned && constant->value < low) { - extraMsg = - str(boost::format( - "the value %1% is negative, but the underlying type %2% is unsigned") % - constant->value % type->toString()); - } else { - extraMsg = - str(boost::format("the value %1% requires %2% bits but the underlying " - "%3% type %4% only contains %5% bits") % - constant->value % required % (type->isSigned ? "signed" : "unsigned") % - type->toString() % type->size); - } - ::error(ErrorType::ERR_TYPE_ERROR, - "%1%: Serialized enum constant value %2% is out of bounds of the underlying " - "type %3%; %4%", - member, constant->value, serEnum->type, extraMsg); - return false; - } - } - return true; -} - -const IR::Node *TypeInference::postorder(IR::SerEnumMember *member) { - /* - The type of the member is initially set in the Type_SerEnum preorder visitor. - Here we check additional constraints and we may correct the member. - if (done()) return member; - */ - const auto *serEnum = findContext(); - CHECK_NULL(serEnum); - const auto *type = checkUnderlyingEnumType(serEnum->type); - if (!type || !checkEnumValueInitializer(type, member->value, serEnum, member)) { - return member; - } - const auto *exprType = getType(member->value); - auto *tvs = unifyCast(member, type, exprType, - "Enum member '%1%' has type '%2%' and not the expected type '%3%'", - {member, exprType, type}); - if (tvs == nullptr) - // error already signalled - return member; - if (tvs->isIdentity()) return member; - - ConstantTypeSubstitution cts(tvs, typeMap, this); - member->value = cts.convert(member->value, getChildContext()); // sets type - if (!typeMap->getType(member)) setType(member, getTypeType(serEnum)); - return member; -} - -const IR::Node *TypeInference::postorder(IR::P4ValueSet *decl) { - if (done()) return decl; - // This is a specialized version of setTypeType - auto canon = canonicalize(decl->elementType); - if (canon != nullptr) { - if (canon != decl->elementType) { - bool errs = learn(canon, this, getChildContext()); - if (errs) return nullptr; - } - if (!canon->is() && !canon->is() && - !canon->is() && !canon->is() && - !canon->is() && !canon->is() && - !canon->is()) - typeError("%1%: Illegal type for value_set element type", decl->elementType); - - auto tt = new IR::Type_Set(canon); - setType(getOriginal(), tt); - setType(decl, tt); - } - return decl; -} - -const IR::Node *TypeInference::postorder(IR::Type_Extern *type) { - if (done()) return type; - setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Method *type) { - auto methodType = type; - if (auto ext = findContext()) { - auto extName = ext->name.name; - if (auto method = findContext()) { - auto name = method->name.name; - if (methodType->returnType && (methodType->returnType->is() || - methodType->returnType->is())) - typeError("%1%: illegal return type for method", method->type->returnType); - if (name == extName) { - // This is a constructor. - if (this->called_by == nullptr && // canonical types violate this rule - method->type->typeParameters != nullptr && - method->type->typeParameters->size() > 0) { - typeError("%1%: Constructors cannot have type parameters", - method->type->typeParameters); - return type; - } - // For constructors we add the type variables of the - // enclosing extern as type parameters. Given - // extern e { e(); } - // the type of method e is in fact e(); - methodType = new IR::Type_Method(type->srcInfo, ext->typeParameters, - type->returnType, type->parameters, name); - } - } - } - (void)setTypeType(methodType); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Action *type) { - (void)setTypeType(type); - BUG_CHECK(type->typeParameters->size() == 0, "%1%: Generic action?", type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Base *type) { - (void)setTypeType(type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Newtype *type) { - (void)setTypeType(type); - auto argType = getTypeType(type->type); - if (!argType->is() && !argType->is() && - !argType->is()) - typeError("%1%: `type' can only be applied to base types", type); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Typedef *tdecl) { - if (done()) return tdecl; - auto type = getType(tdecl->type); - if (type == nullptr) return tdecl; - BUG_CHECK(type->is(), "%1%: expected a TypeType", type); - auto stype = type->to()->type; - if (auto gen = stype->to()) { - if (gen->getTypeParameters()->size() != 0) { - typeError("%1%: no type parameters supplied for generic type", tdecl->type); - return tdecl; - } - } - setType(getOriginal(), type); - setType(tdecl, type); - return tdecl; -} - -const IR::Node *TypeInference::postorder(IR::Type_Stack *type) { - auto canon = setTypeType(type); - if (canon == nullptr) return type; - - auto etype = canon->to()->elementType; - if (etype == nullptr) return type; - - if (!etype->is() && !etype->is() && - !etype->is()) - typeError("Header stack %1% used with non-header type %2%", type, etype->toString()); - return type; -} - -/// Validate the fields of a struct type using the supplied checker. -/// The checker returns "false" when a field is invalid. -/// Return true on success -bool TypeInference::validateFields(const IR::Type *type, - std::function checker) const { - if (type == nullptr) return false; - BUG_CHECK(type->is(), "%1%; expected a Struct-like", type); - auto strct = type->to(); - bool err = false; - for (auto field : strct->fields) { - auto ftype = getType(field); - if (ftype == nullptr) return false; - if (!checker(ftype)) { - typeError("Field '%1%' of '%2%' cannot have type '%3%'", field, type->toString(), - field->type); - err = true; - } - } - return !err; -} - -const IR::Node *TypeInference::postorder(IR::StructField *field) { - if (done()) return field; - auto canon = getTypeType(field->type); - if (canon == nullptr) return field; - - setType(getOriginal(), canon); - setType(field, canon); - return field; -} - -const IR::Node *TypeInference::postorder(IR::Type_Header *type) { - auto canon = setTypeType(type); - auto validator = [this](const IR::Type *t) { - while (t->is()) t = getTypeType(t->to()->type); - return t->is() || t->is() || - (t->is() && onlyBitsOrBitStructs(t)) || t->is() || - t->is() || t->is() || - t->is(); - }; - validateFields(canon, validator); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_Struct *type) { - auto canon = setTypeType(type); - auto validator = [this](const IR::Type *t) { - while (auto *nt = t->to()) t = getTypeType(nt->type); - return t->is() || t->is() || t->is() || - t->is() || t->is() || t->is() || - t->is() || t->is() || t->is() || - t->is() || t->is() || - t->is() || t->is() || - t->is() || t->is(); - }; - (void)validateFields(canon, validator); - return type; -} - -const IR::Node *TypeInference::postorder(IR::Type_HeaderUnion *type) { - auto canon = setTypeType(type); - auto validator = [](const IR::Type *t) { - return t->is() || t->is() || - t->is(); - }; - (void)validateFields(canon, validator); - return type; -} - -////////////////////////////////////////////////// expressions - -const IR::Node *TypeInference::postorder(IR::Parameter *param) { - if (done()) return param; - const IR::Type *paramType = getTypeType(param->type); - if (paramType == nullptr) return param; - BUG_CHECK(!paramType->is(), "%1%: unexpected type", paramType); - - if (paramType->is() || paramType->is()) { - typeError("%1%: parameter cannot have type %2%", param, paramType); - return param; - } - - if (!readOnly && paramType->is()) { - // We only give these errors if we are no in 'readOnly' mode: - // this prevents giving a confusing error message to the user. - if (param->direction != IR::Direction::None) { - typeError("%1%: parameters with type %2% must be directionless", param, paramType); - return param; - } - if (findContext()) { - typeError("%1%: actions cannot have parameters with type %2%", param, paramType); - return param; - } - } - - // The parameter type cannot have free type variables - if (auto *gen = paramType->to()) { - auto tp = gen->getTypeParameters(); - if (!tp->empty()) { - typeError("Type parameters needed for %1%", param->name); - return param; - } - } - - if (param->defaultValue) { - if (!typeMap->isCompileTimeConstant(param->defaultValue)) - typeError("%1%: expression must be a compile-time constant", param->defaultValue); - } - - setType(getOriginal(), paramType); - setType(param, paramType); - return param; -} - -const IR::Node *TypeInference::postorder(IR::Constant *expression) { - if (done()) return expression; - auto type = getTypeType(expression->type); - if (type == nullptr) return expression; - setType(getOriginal(), type); - setType(expression, type); - setCompileTimeConstant(getOriginal()); - setCompileTimeConstant(expression); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::StringLiteral *expression) { - if (done()) return expression; - setType(getOriginal(), IR::Type_String::get()); - setType(expression, IR::Type_String::get()); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::BoolLiteral *expression) { - if (done()) return expression; - setType(getOriginal(), IR::Type_Boolean::get()); - setType(expression, IR::Type_Boolean::get()); - setCompileTimeConstant(getOriginal()); - setCompileTimeConstant(expression); - return expression; -} - -bool TypeInference::containsActionEnum(const IR::Type *type) const { - if (auto st = type->to()) { - if (auto field = st->getField(IR::Type_Table::action_run)) { - auto ft = getTypeType(field->type); - if (ft->is()) return true; - } - } - return false; -} - -// Returns false on error -bool TypeInference::compare(const IR::Node *errorPosition, const IR::Type *ltype, - const IR::Type *rtype, Comparison *compare) { - if (ltype->is() || rtype->is()) { - // Actions return Type_Action instead of void. - typeError("%1% and %2% cannot be compared", compare->left, compare->right); - return false; - } - if (ltype->is() || rtype->is()) { - typeError("%1% and %2%: tables cannot be compared", compare->left, compare->right); - return false; - } - if (ltype->is() || rtype->is()) { - typeError("%1% and %2%: externs cannot be compared", compare->left, compare->right); - return false; - } - if (containsActionEnum(ltype) || containsActionEnum(rtype)) { - typeError("%1% and %2%: table application results cannot be compared", compare->left, - compare->right); - return false; - } - - bool defined = false; - if (typeMap->equivalent(ltype, rtype) && - (!ltype->is() && !ltype->is()) && - !ltype->to()) { - defined = true; - } else if (ltype->is() && rtype->is() && - typeMap->equivalent(ltype, rtype)) { - defined = true; - } else if (ltype->is() && rtype->is()) { - auto tvs = unify(errorPosition, ltype, rtype); - if (tvs == nullptr) return false; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - compare->left = cts.convert(compare->left, getChildContext()); - compare->right = cts.convert(compare->right, getChildContext()); - } - defined = true; - } else if (auto se = rtype->to()) { - // This can only happen in a switch statement, other comparisons - // eliminate SerEnums before calling here. - if (typeMap->equivalent(ltype, se->type)) defined = true; - } else { - auto ls = ltype->to(); - auto rs = rtype->to(); - if (ls != nullptr || rs != nullptr) { - if (ls != nullptr && rs != nullptr) { - typeError("%1%: cannot compare structure-valued expressions with unknown types", - errorPosition); - return false; - } - - bool lcst = isCompileTimeConstant(compare->left); - bool rcst = isCompileTimeConstant(compare->right); - TypeVariableSubstitution *tvs; - if (ls == nullptr) { - tvs = unify(errorPosition, ltype, rtype); - } else { - tvs = unify(errorPosition, rtype, ltype); - } - if (tvs == nullptr) return false; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - compare->left = cts.convert(compare->left, getChildContext()); - compare->right = cts.convert(compare->right, getChildContext()); - } - - if (ls != nullptr) { - auto l = compare->left->to(); - CHECK_NULL(l); // struct initializers are the only expressions that can - // have StructUnknown types - BUG_CHECK(rtype->is(), "%1%: expected a struct", rtype); - auto type = new IR::Type_Name(rtype->to()->name); - compare->left = - new IR::StructExpression(compare->left->srcInfo, type, type, l->components); - setType(compare->left, rtype); - if (lcst) setCompileTimeConstant(compare->left); - } else { - auto r = compare->right->to(); - CHECK_NULL(r); // struct initializers are the only expressions that can - // have StructUnknown types - BUG_CHECK(ltype->is(), "%1%: expected a struct", ltype); - auto type = new IR::Type_Name(ltype->to()->name); - compare->right = - new IR::StructExpression(compare->right->srcInfo, type, type, r->components); - setType(compare->right, rtype); - if (rcst) setCompileTimeConstant(compare->right); - } - defined = true; - } - - // comparison between structs and list expressions is allowed - if ((ltype->is() && rtype->is()) || - (ltype->is() && rtype->is())) { - if (!ltype->is()) { - // swap - auto type = ltype; - ltype = rtype; - rtype = type; - } - - auto tvs = unify(errorPosition, ltype, rtype); - if (tvs == nullptr) return false; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - compare->left = cts.convert(compare->left, getChildContext()); - compare->right = cts.convert(compare->right, getChildContext()); - } - defined = true; - } - } - - if (!defined) { - typeError("'%1%' with type '%2%' cannot be compared to '%3%' with type '%4%'", - compare->left, ltype, compare->right, rtype); - return false; - } - return true; -} - -const IR::Node *TypeInference::postorder(IR::Operation_Relation *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - bool equTest = expression->is() || expression->is(); - if (auto l = ltype->to()) ltype = getTypeType(l->type); - if (auto r = rtype->to()) rtype = getTypeType(r->type); - BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); - - if (ltype->is() && rtype->is()) { - // This can happen because we are replacing some constant functions with - // constants during type checking - auto result = constantFold(expression); - setType(getOriginal(), IR::Type_Boolean::get()); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } else if (ltype->is() && rtype->is()) { - auto e = expression->clone(); - e->left = new IR::Cast(e->left->srcInfo, rtype, e->left); - setType(e->left, rtype); - ltype = rtype; - expression = e; - } else if (rtype->is() && ltype->is()) { - auto e = expression->clone(); - e->right = new IR::Cast(e->right->srcInfo, ltype, e->right); - setType(e->right, ltype); - rtype = ltype; - expression = e; - } - - if (equTest) { - Comparison c; - c.left = expression->left; - c.right = expression->right; - auto b = compare(expression, ltype, rtype, &c); - if (!b) return expression; - expression->left = c.left; - expression->right = c.right; - } else { - if (!ltype->is() || !rtype->is() || !(ltype->equiv(*rtype))) { - typeError("%1%: not defined on %2% and %3%", expression, ltype->toString(), - rtype->toString()); - return expression; - } - } - setType(getOriginal(), IR::Type_Boolean::get()); - setType(expression, IR::Type_Boolean::get()); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Concat *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - if (ltype->is()) { - typeError("Please specify a width for the operand %1% of a concatenation", - expression->left); - return expression; - } - if (rtype->is()) { - typeError("Please specify a width for the operand %1% of a concatenation", - expression->right); - return expression; - } - - bool castLeft = false; - bool castRight = false; - if (auto se = ltype->to()) { - ltype = getTypeType(se->type); - castLeft = true; - } - if (auto se = rtype->to()) { - rtype = getTypeType(se->type); - castRight = true; - } - if (ltype == nullptr || rtype == nullptr) { - // getTypeType should have already taken care of the error message - return expression; - } - if (!ltype->is() || !rtype->is()) { - typeError("%1%: Concatenation not defined on %2% and %3%", expression, ltype->toString(), - rtype->toString()); - return expression; - } - auto bl = ltype->to(); - auto br = rtype->to(); - const IR::Type *resultType = IR::Type_Bits::get(bl->size + br->size, bl->isSigned); - - if (castLeft) { - auto e = expression->clone(); - e->left = new IR::Cast(e->left->srcInfo, bl, e->left); - if (isCompileTimeConstant(expression->left)) setCompileTimeConstant(e->left); - setType(e->left, ltype); - expression = e; - } - if (castRight) { - auto e = expression->clone(); - e->right = new IR::Cast(e->right->srcInfo, br, e->right); - if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(e->right); - setType(e->right, rtype); - expression = e; - } - - resultType = canonicalize(resultType); - if (resultType != nullptr) { - setType(getOriginal(), resultType); - setType(expression, resultType); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - } - return expression; -} - -/** - * compute the type of table keys. - * Used to typecheck pre-defined entries. - */ -const IR::Node *TypeInference::postorder(IR::Key *key) { - // compute the type and store it in typeMap - auto keyTuple = new IR::Type_Tuple; - for (auto ke : key->keyElements) { - auto kt = typeMap->getType(ke->expression); - if (kt == nullptr) { - LOG2("Bailing out for " << dbp(ke)); - return key; - } - keyTuple->components.push_back(kt); - } - LOG2("Setting key type to " << dbp(keyTuple)); - setType(key, keyTuple); - setType(getOriginal(), keyTuple); - return key; -} - -/** - * typecheck a table initializer entry list - */ -const IR::Node *TypeInference::preorder(IR::EntriesList *el) { - if (done()) return el; - auto table = findContext(); - BUG_CHECK(table != nullptr, "%1% entries not within a table", el); - const IR::Key *key = table->getKey(); - if (key == nullptr) { - if (el->size() != 0) - typeError("Entries cannot be specified for a table with no key %1%", table); - prune(); - return el; - } - auto keyTuple = typeMap->getType(key); // direct typeMap call to skip checks - if (keyTuple == nullptr) { - // The keys have to be before the entries list. If they are not, - // at this point they have not yet been type-checked. - if (key->srcInfo.isValid() && el->srcInfo.isValid() && key->srcInfo >= el->srcInfo) { - typeError("%1%: Entries list must be after table key %2%", el, key); - prune(); - return el; - } - // otherwise the type-checking of the keys must have failed - } - return el; -} - -/** - * typecheck a table initializer entry - * - * The invariants are: - * - table keys and entry keys must have the same length - * - entry key elements must be compile time constants - * - actionRefs in entries must be in the action list - * - table keys must have been type checked before entries - * - * Moreover, the EntriesList visitor should have checked for the table - * invariants. - */ -const IR::Node *TypeInference::postorder(IR::Entry *entry) { - if (done()) return entry; - auto table = findContext(); - if (table == nullptr) return entry; - const IR::Key *key = table->getKey(); - if (key == nullptr) return entry; - auto keyTuple = getType(key); - if (keyTuple == nullptr) return entry; - - auto entryKeyType = getType(entry->keys); - if (entryKeyType == nullptr) return entry; - if (auto ts = entryKeyType->to()) entryKeyType = ts->elementType; - if (entry->singleton) { - if (auto tl = entryKeyType->to()) { - // An entry of _ does not have type Tuple, but rather Type_Dontcare - if (tl->getSize() == 1 && tl->components.at(0)->is()) - entryKeyType = tl->components.at(0); - } - } - - auto keyset = entry->getKeys(); - if (keyset == nullptr || !(keyset->is())) { - typeError("%1%: key expression must be tuple", keyset); - return entry; - } - if (keyset->components.size() < key->keyElements.size()) { - typeError("%1%: Size of entry keyset must match the table key set size", keyset); - return entry; - } - - bool nonConstantKeys = false; - for (auto ke : keyset->components) - if (!isCompileTimeConstant(ke)) { - typeError("Key entry must be a compile time constant: %1%", ke); - nonConstantKeys = true; - } - if (nonConstantKeys) return entry; - - if (entry->priority && !isCompileTimeConstant(entry->priority)) { - typeError("Entry priority must be a compile time constant: %1%", entry->priority); - return entry; - } - - TypeVariableSubstitution *tvs = - unifyCast(entry, keyTuple, entryKeyType, - "Table entry has type '%1%' which is not the expected type '%2%'", - {keyTuple, entryKeyType}); - if (tvs == nullptr) return entry; - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto ks = cts.convert(keyset, getChildContext()); - if (::errorCount() > 0) return entry; - - if (ks != keyset) - entry = new IR::Entry(entry->srcInfo, entry->annotations, entry->isConst, entry->priority, - ks->to(), entry->action, entry->singleton); - - auto actionRef = entry->getAction(); - auto ale = validateActionInitializer(actionRef); - if (ale != nullptr) { - auto anno = ale->getAnnotation(IR::Annotation::defaultOnlyAnnotation); - if (anno != nullptr) { - typeError("%1%: Action marked with %2% used in table", entry, - IR::Annotation::defaultOnlyAnnotation); - return entry; - } - } - return entry; -} - -const IR::Node *TypeInference::postorder(IR::ListExpression *expression) { - if (done()) return expression; - bool constant = true; - auto components = new IR::Vector(); - for (auto c : expression->components) { - if (!isCompileTimeConstant(c)) constant = false; - auto type = getType(c); - if (type == nullptr) return expression; - components->push_back(type); - } - - auto tupleType = new IR::Type_List(expression->srcInfo, *components); - auto type = canonicalize(tupleType); - if (type == nullptr) return expression; - setType(getOriginal(), type); - setType(expression, type); - if (constant) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Invalid *expression) { - if (done()) return expression; - auto unk = IR::Type_Unknown::get(); - setType(expression, unk); - setType(getOriginal(), unk); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::InvalidHeader *expression) { - if (done()) return expression; - auto type = getTypeType(expression->headerType); - auto concreteType = type; - if (auto ts = concreteType->to()) concreteType = ts->substituted; - if (!concreteType->is()) { - typeError("%1%: invalid header expression has a non-header type `%2%`", expression, type); - return expression; - } - setType(expression, type); - setType(getOriginal(), type); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::InvalidHeaderUnion *expression) { - if (done()) return expression; - auto type = getTypeType(expression->headerUnionType); - auto concreteType = type; - if (auto ts = concreteType->to()) concreteType = ts->substituted; - if (!concreteType->is()) { - typeError("%1%: does not have a header_union type `%2%`", expression, type); - return expression; - } - setType(expression, type); - setType(getOriginal(), type); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::P4ListExpression *expression) { - if (done()) return expression; - bool constant = true; - auto elementType = getTypeType(expression->elementType); - auto vec = new IR::Vector(); - bool changed = false; - for (auto c : expression->components) { - if (!isCompileTimeConstant(c)) constant = false; - auto type = getType(c); - if (type == nullptr) return expression; - auto tvs = unify(expression, elementType, type, - "Vector element type '%1%' does not match expected type '%2%'", - {type, elementType}); - if (tvs == nullptr) return expression; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto converted = cts.convert(c, getChildContext()); - vec->push_back(converted); - changed = changed || converted != c; - } else { - vec->push_back(c); - } - } - - if (changed) - expression = new IR::P4ListExpression(expression->srcInfo, *vec, elementType->getP4Type()); - auto type = new IR::Type_P4List(expression->srcInfo, elementType); - setType(getOriginal(), type); - setType(expression, type); - if (constant) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::HeaderStackExpression *expression) { - if (done()) return expression; - bool constant = true; - auto stackType = getTypeType(expression->headerStackType); - if (auto st = stackType->to()) { - auto elementType = st->elementType; - auto vec = new IR::Vector(); - bool changed = false; - if (expression->size() != st->getSize()) { - typeError("%1%: number of initializers %2% has to match stack size %3%", expression, - expression->size(), st->getSize()); - return expression; - } - for (auto c : expression->components) { - if (!isCompileTimeConstant(c)) constant = false; - auto type = getType(c); - if (type == nullptr) return expression; - auto tvs = unify(expression, elementType, type, - "Stack element type '%1%' does not match expected type '%2%'", - {type, elementType}); - if (tvs == nullptr) return expression; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto converted = cts.convert(c, getChildContext()); - vec->push_back(converted); - changed = true; - } else { - vec->push_back(c); - } - if (changed) - expression = new IR::HeaderStackExpression(expression->srcInfo, *vec, stackType); - } - } else { - typeError("%1%: header stack expression has an incorrect type `%2%`", expression, - stackType); - return expression; - } - - setType(getOriginal(), stackType); - setType(expression, stackType); - if (constant) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::StructExpression *expression) { - if (done()) return expression; - bool constant = true; - auto components = new IR::IndexedVector(); - for (auto c : expression->components) { - if (!isCompileTimeConstant(c->expression)) constant = false; - auto type = getType(c->expression); - if (type == nullptr) return expression; - components->push_back(new IR::StructField(c->name, type)); - } - - // This is the type inferred by looking at the fields. - const IR::Type *structType = - new IR::Type_UnknownStruct(expression->srcInfo, "unknown struct", *components); - structType = canonicalize(structType); - - const IR::Expression *result = expression; - if (expression->structType != nullptr) { - // We know the exact type of the initializer - auto desired = getTypeType(expression->structType); - if (desired == nullptr) return expression; - auto tvs = unify(expression, desired, structType, - "Initializer type '%1%' does not match expected type '%2%'", - {structType, desired}); - if (tvs == nullptr) return expression; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - result = cts.convert(expression, getChildContext()); - } - structType = desired; - } - setType(getOriginal(), structType); - setType(expression, structType); - if (constant) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return result; -} - -const IR::Node *TypeInference::postorder(IR::ArrayIndex *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - auto hst = ltype->to(); - - int index = -1; - if (auto cst = expression->right->to()) { - if (hst && checkArrays && !cst->fitsInt()) { - typeError("Index too large: %1%", cst); - return expression; - } - index = cst->asInt(); - if (hst && checkArrays && index < 0) { - typeError("%1%: Negative array index %2%", expression, cst); - return expression; - } - } - // if index is negative here it means it's not a constant - - if ((index < 0) && !rtype->is() && !rtype->is() && - !rtype->is()) { - typeError("Array index %1% must be an integer, but it has type %2%", expression->right, - rtype->toString()); - return expression; - } - - const IR::Type *type = nullptr; - if (hst) { - if (checkArrays && hst->sizeKnown()) { - int size = hst->getSize(); - if (index >= 0 && index >= size) { - typeError("Array index %1% larger or equal to array size %2%", expression->right, - hst->size); - return expression; - } - } - type = hst->elementType; - } else if (auto tup = ltype->to()) { - if (index < 0) { - typeError("Tuple index %1% must be constant", expression->right); - return expression; - } - if (static_cast(index) >= tup->getSize()) { - typeError("Tuple index %1% larger than tuple size %2%", expression->right, - tup->getSize()); - return expression; - } - type = tup->components.at(index); - if (isCompileTimeConstant(expression->left)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - } else { - typeError("Indexing %1% applied to non-array and non-tuple type %2%", expression, - ltype->toString()); - return expression; - } - if (isLeftValue(expression->left)) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } - setType(getOriginal(), type); - setType(expression, type); - return expression; -} - -const IR::Node *TypeInference::binaryBool(const IR::Operation_Binary *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - if (!ltype->is() || !rtype->is()) { - typeError("%1%: not defined on %2% and %3%", expression, ltype->toString(), - rtype->toString()); - return expression; - } - setType(getOriginal(), IR::Type_Boolean::get()); - setType(expression, IR::Type_Boolean::get()); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::binaryArith(const IR::Operation_Binary *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - bool castLeft = false; - bool castRight = false; - - if (auto se = ltype->to()) { - ltype = getTypeType(se->type); - castLeft = true; - } - if (auto se = rtype->to()) { - rtype = getTypeType(se->type); - castRight = true; - } - BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); - - const IR::Type_Bits *bl = ltype->to(); - const IR::Type_Bits *br = rtype->to(); - if (bl == nullptr && !ltype->is()) { - typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->left, ltype->toString()); - return expression; - } else if (br == nullptr && !rtype->is()) { - typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->right, rtype->toString()); - return expression; - } else if (ltype->is() && rtype->is()) { - auto t = IR::Type_InfInt::get(); - auto result = constantFold(expression); - setType(getOriginal(), t); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - - const IR::Type *resultType = ltype; - if (bl != nullptr && br != nullptr) { - if (bl->size != br->size) { - typeError("%1%: Cannot operate on values with different widths %2% and %3%", expression, - bl->size, br->size); - return expression; - } - if (bl->isSigned != br->isSigned) { - typeError("%1%: Cannot operate on values with different signs", expression); - return expression; - } - } - if ((bl == nullptr && br != nullptr) || castLeft) { - // must insert cast on the left - auto leftResultType = br; - if (castLeft && !br) leftResultType = bl; - auto e = expression->clone(); - e->left = new IR::Cast(e->left->srcInfo, leftResultType, e->left); - setType(e->left, leftResultType); - if (isCompileTimeConstant(expression->left)) { - e->left = constantFold(e->left); - setCompileTimeConstant(e->left); - } - expression = e; - resultType = leftResultType; - } - if ((bl != nullptr && br == nullptr) || castRight) { - auto e = expression->clone(); - auto rightResultType = bl; - if (castRight && !bl) rightResultType = br; - e->right = new IR::Cast(e->right->srcInfo, rightResultType, e->right); - setType(e->right, rightResultType); - if (isCompileTimeConstant(expression->right)) { - e->right = constantFold(e->right); - setCompileTimeConstant(e->right); - } - expression = e; - resultType = rightResultType; - } - - setType(getOriginal(), resultType); - setType(expression, resultType); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::unsBinaryArith(const IR::Operation_Binary *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (auto se = rtype->to()) rtype = getTypeType(se->type); - - const IR::Type_Bits *bl = ltype->to(); - if (bl != nullptr && bl->isSigned) { - typeError("%1%: Cannot operate on signed values", expression); - return expression; - } - const IR::Type_Bits *br = rtype->to(); - if (br != nullptr && br->isSigned) { - typeError("%1%: Cannot operate on signed values", expression); - return expression; - } - - auto cleft = expression->left->to(); - if (cleft != nullptr) { - if (cleft->value < 0) { - typeError("%1%: not defined on negative numbers", expression); - return expression; - } - } - auto cright = expression->right->to(); - if (cright != nullptr) { - if (cright->value < 0) { - typeError("%1%: not defined on negative numbers", expression); - return expression; - } - } - - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return binaryArith(expression); -} - -const IR::Node *TypeInference::shift(const IR::Operation_Binary *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (ltype == nullptr) { - // getTypeType should have already taken care of the error message - return expression; - } - auto lt = ltype->to(); - if (auto cst = expression->right->to()) { - if (!cst->fitsInt()) { - typeError("Shift amount too large: %1%", cst); - return expression; - } - int shift = cst->asInt(); - if (shift < 0) { - typeError("%1%: Negative shift amount %2%", expression, cst); - return expression; - } - if (lt != nullptr && shift >= lt->size) - warn(ErrorType::WARN_OVERFLOW, "%1%: shifting value with %2% bits by %3%", expression, - lt->size, shift); - // If the amount is signed but positive, make it unsigned - if (auto bt = rtype->to()) { - if (bt->isSigned) { - rtype = IR::Type_Bits::get(rtype->srcInfo, bt->width_bits(), false); - auto amt = new IR::Constant(cst->srcInfo, rtype, cst->value, cst->base); - if (expression->is()) { - expression = new IR::Shl(expression->srcInfo, expression->left, amt); - } else { - expression = new IR::Shr(expression->srcInfo, expression->left, amt); - } - setCompileTimeConstant(expression->right); - setType(expression->right, rtype); - } - } - } - - if (rtype->is() && rtype->to()->isSigned) { - typeError("%1%: Shift amount must be an unsigned number", expression->right); - return expression; - } - - if (!lt && !ltype->is()) { - typeError("%1% left operand of shift must be a numeric type, not %2%", expression, - ltype->toString()); - return expression; - } - - if (ltype->is() && !rtype->is() && - !isCompileTimeConstant(expression->right)) { - typeError( - "%1%: shift result type is arbitrary-precision int, but right operand is not constant; " - "width of left operand of shift needs to be specified or both operands need to be " - "constant", - expression); - return expression; - } - - setType(expression, ltype); - setType(getOriginal(), ltype); - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -// Handle .. and &&& -const IR::Node *TypeInference::typeSet(const IR::Operation_Binary *expression) { - if (done()) return expression; - auto ltype = getType(expression->left); - auto rtype = getType(expression->right); - if (ltype == nullptr || rtype == nullptr) return expression; - - auto leftType = ltype; // save original type - if (auto se = ltype->to()) ltype = getTypeType(se->type); - if (auto se = rtype->to()) rtype = getTypeType(se->type); - BUG_CHECK(ltype && rtype, "Invalid Type_SerEnum/getTypeType"); - - // The following section is very similar to "binaryArith()" above - const IR::Type_Bits *bl = ltype->to(); - const IR::Type_Bits *br = rtype->to(); - if (bl == nullptr && !ltype->is()) { - typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->left, ltype->toString()); - return expression; - } else if (br == nullptr && !rtype->is()) { - typeError("%1%: cannot be applied to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->right, rtype->toString()); - return expression; - } - - const IR::Type *sameType = leftType; - if (bl != nullptr && br != nullptr) { - if (!typeMap->equivalent(bl, br)) { - typeError("%1%: Cannot operate on values with different types %2% and %3%", expression, - bl->toString(), br->toString()); - return expression; - } - } else if (bl == nullptr && br != nullptr) { - auto e = expression->clone(); - e->left = new IR::Cast(e->left->srcInfo, rtype, e->left); - setCompileTimeConstant(e->left); - expression = e; - sameType = rtype; - setType(e->left, sameType); - } else if (bl != nullptr && br == nullptr) { - auto e = expression->clone(); - e->right = new IR::Cast(e->right->srcInfo, ltype, e->right); - setCompileTimeConstant(e->right); - expression = e; - setType(e->right, ltype); - sameType = leftType; // Not ltype: SerEnum &&& Bit is Set - } else { - // both are InfInt: use same exact type for both sides, so it is properly - // set after unification - // FIXME -- the below is obviously wrong and just serves to tweak when precisely - // the type will be inferred -- papering over bugs elsewhere in typechecking, - // avoiding the BUG_CHECK(!readOnly... in end_apply/apply_visitor above. - // (maybe just need learner->readOnly = false in TypeInference::learn above?) - auto r = expression->right->clone(); - auto e = expression->clone(); - if (isCompileTimeConstant(expression->right)) setCompileTimeConstant(r); - e->right = r; - expression = e; - setType(r, sameType); - } - - auto resultType = new IR::Type_Set(sameType->srcInfo, sameType); - typeMap->setType(expression, resultType); - typeMap->setType(getOriginal(), resultType); - - if (isCompileTimeConstant(expression->left) && isCompileTimeConstant(expression->right)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - - return expression; -} - -const IR::Node *TypeInference::postorder(IR::LNot *expression) { - if (done()) return expression; - auto type = getType(expression->expr); - if (type == nullptr) return expression; - if (!(*type == *IR::Type_Boolean::get())) { - typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), - expression->expr, type->toString()); - } else { - setType(expression, IR::Type_Boolean::get()); - setType(getOriginal(), IR::Type_Boolean::get()); - } - if (isCompileTimeConstant(expression->expr)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Neg *expression) { - if (done()) return expression; - auto type = getType(expression->expr); - if (type == nullptr) return expression; - - if (auto se = type->to()) type = getTypeType(se->type); - BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); - - if (type->is()) { - setType(getOriginal(), type); - setType(expression, type); - } else if (type->is()) { - setType(getOriginal(), type); - setType(expression, type); - } else { - typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), - expression->expr, type->toString()); - } - if (isCompileTimeConstant(expression->expr)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::UPlus *expression) { - if (done()) return expression; - auto type = getType(expression->expr); - if (type == nullptr) return expression; - - if (auto se = type->to()) type = getTypeType(se->type); - BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); - - if (type->is()) { - setType(getOriginal(), type); - setType(expression, type); - } else if (type->is()) { - setType(getOriginal(), type); - setType(expression, type); - } else { - typeError("Cannot apply %1% to value %2% of type %3%", expression->getStringOp(), - expression->expr, type->toString()); - } - if (isCompileTimeConstant(expression->expr)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Cmpl *expression) { - if (done()) return expression; - auto type = getType(expression->expr); - if (type == nullptr) return expression; - - if (auto se = type->to()) type = getTypeType(se->type); - BUG_CHECK(type, "Invalid Type_SerEnum/getTypeType"); - - if (type->is()) { - typeError("'%1%' cannot be applied to an operand with an unknown width"); - } else if (type->is()) { - setType(getOriginal(), type); - setType(expression, type); - } else { - typeError("Cannot apply operation '%1%' to expression '%2%' with type '%3%'", - expression->getStringOp(), expression->expr, type->toString()); - } - if (isCompileTimeConstant(expression->expr)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Cast *expression) { - if (done()) return expression; - const IR::Type *sourceType = getType(expression->expr); - const IR::Type *castType = getTypeType(expression->destType); - if (sourceType == nullptr || castType == nullptr) return expression; - - auto concreteType = castType; - if (auto tsc = castType->to()) concreteType = tsc->substituted; - if (auto st = concreteType->to()) { - if (auto se = expression->expr->to()) { - // Interpret (S) { kvpairs } as a struct initializer expression - // instead of a cast to a struct. - if (se->type == nullptr || se->type->is() || - se->type->is()) { - auto type = castType->getP4Type(); - setType(type, new IR::Type_Type(st)); - auto sie = new IR::StructExpression(se->srcInfo, type, se->components); - auto result = postorder(sie); // may insert casts - setType(result, st); - if (isCompileTimeConstant(se)) { - setCompileTimeConstant(result->to()); - setCompileTimeConstant(getOriginal()); - } - return result; - } else { - typeError("%1%: cast not supported", expression->destType); - return expression; - } - } else if (expression->expr->is()) { - auto result = assignment(expression, st, expression->expr); - return result; - } else if (auto ih = expression->expr->to()) { - auto type = castType->getP4Type(); - auto concreteCastType = castType; - if (auto ts = castType->to()) - concreteCastType = ts->substituted; - if (concreteCastType->is()) { - setType(type, new IR::Type_Type(castType)); - auto result = new IR::InvalidHeader(ih->srcInfo, type, type); - setType(result, castType); - return result; - } else if (concreteCastType->is()) { - setType(type, new IR::Type_Type(castType)); - auto result = new IR::InvalidHeaderUnion(ih->srcInfo, type, type); - setType(result, castType); - return result; - } else { - typeError("%1%: 'invalid' expression type `%2%` must be a header or header union", - expression, castType); - return expression; - } - } - } - if (auto lt = concreteType->to()) { - auto listElementType = lt->elementType; - if (auto le = expression->expr->to()) { - IR::Vector vec; - bool isConstant = true; - for (size_t i = 0; i < le->size(); i++) { - auto compI = le->components.at(i); - auto src = assignment(expression, listElementType, compI); - if (!isCompileTimeConstant(src)) isConstant = false; - vec.push_back(src); - } - auto vecType = castType->getP4Type(); - setType(vecType, new IR::Type_Type(lt)); - auto result = new IR::P4ListExpression(le->srcInfo, vec, listElementType->getP4Type()); - setType(result, lt); - if (isConstant) { - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - } - return result; - } else { - typeError("%1%: casts to list not supported", expression); - return expression; - } - } - if (concreteType->is()) { - if (expression->expr->is()) { - auto result = assignment(expression, concreteType, expression->expr); - return result; - } else { - typeError("%1%: casts to header stack not supported", expression); - return expression; - } - } - - if (!castType->is() && !castType->is() && - !castType->is() && !castType->is() && - !castType->is() && !castType->is()) { - typeError("%1%: cast not supported", expression->destType); - return expression; - } - - if (!canCastBetween(castType, sourceType)) { - // This cast is not legal directly, but let's try to see whether - // performing a substitution can help. This will allow the use - // of constants on the RHS. - const IR::Type *destType = castType; - while (destType->is()) - destType = getTypeType(destType->to()->type); - - auto tvs = unify(expression, destType, sourceType, "Cannot cast from '%1%' to '%2%'", - {sourceType, castType}); - if (tvs == nullptr) return expression; - const IR::Expression *rhs = expression->expr; - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - rhs = cts.convert(expression->expr, getChildContext()); // sets type - } - if (rhs != expression->expr) { - // if we are here we have performed a substitution on the rhs - expression = new IR::Cast(expression->srcInfo, expression->destType, rhs); - sourceType = getTypeType(expression->destType); - } - if (!canCastBetween(castType, sourceType)) - typeError("%1%: Illegal cast from %2% to %3%", expression, sourceType->toString(), - castType->toString()); - } - setType(expression, castType); - setType(getOriginal(), castType); - if (isCompileTimeConstant(expression->expr)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::PathExpression *expression) { - if (done()) return expression; - auto decl = getDeclaration(expression->path, !errorOnNullDecls); - if (errorOnNullDecls && decl == nullptr) { - typeError("%1%: Cannot resolve declaration", expression); - return expression; - } - const IR::Type *type = nullptr; - if (auto tbl = decl->to()) { - if (auto current = findContext()) { - if (current->name == tbl->name) { - typeError("%1%: Cannot refer to the containing table %2%", expression, tbl); - return expression; - } - } - } else if (decl->is() || decl->is()) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - // For MatchKind and Errors all ids have a type that has been set - // while processing Type_Error or Declaration_Matchkind - auto declType = typeMap->getType(decl->getNode()); - if (decl->is() && declType && - (declType->is() || declType->is())) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - - if (decl->is()) { - type = IR::Type_State::get(); - } else if (decl->is()) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } else if (decl->is()) { - auto paramDecl = decl->to(); - if (paramDecl->direction == IR::Direction::InOut || - paramDecl->direction == IR::Direction::Out) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } else if (paramDecl->direction == IR::Direction::None) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - } else if (decl->is() || decl->is()) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } else if (decl->is() || decl->is()) { - type = getType(decl->getNode()); - // Each method invocation uses fresh type variables - if (type != nullptr) - // may be nullptr because typechecking may have failed - type = cloneWithFreshTypeVariables(type->to()); - } else if (decl->is()) { - typeError("%1%: Type cannot be used here, expecting an expression.", expression); - return expression; - } - - if (type == nullptr) { - type = getType(decl->getNode()); - if (type == nullptr) return expression; - } - - setType(getOriginal(), type); - setType(expression, type); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Slice *expression) { - if (done()) return expression; - const IR::Type *type = getType(expression->e0); - if (type == nullptr) return expression; - - if (auto se = type->to()) type = getTypeType(se->type); - - if (!type->is()) { - typeError("%1%: bit extraction only defined for bit<> types", expression); - return expression; - } - - auto e1type = getType(expression->e1); - if (e1type && e1type->is()) { - auto ei = EnumInstance::resolve(expression->e1, typeMap); - CHECK_NULL(ei); - auto sei = ei->to(); - if (sei == nullptr) { - typeError("%1%: slice bit index values must be constants", expression->e1); - return expression; - } - expression->e1 = sei->value; - } - auto e2type = getType(expression->e2); - if (e2type && e2type->is()) { - auto ei = EnumInstance::resolve(expression->e2, typeMap); - CHECK_NULL(ei); - auto sei = ei->to(); - if (sei == nullptr) { - typeError("%1%: slice bit index values must be constants", expression->e2); - return expression; - } - expression->e2 = sei->value; - } - - auto bst = type->to(); - if (!expression->e1->is()) { - typeError("%1%: slice bit index values must be constants", expression->e1); - return expression; - } - if (!expression->e2->is()) { - typeError("%1%: slice bit index values must be constants", expression->e2); - return expression; - } - - auto msb = expression->e1->checkedTo(); - auto lsb = expression->e2->checkedTo(); - if (!msb->fitsInt()) { - typeError("%1%: bit index too large", msb); - return expression; - } - if (!lsb->fitsInt()) { - typeError("%1%: bit index too large", lsb); - return expression; - } - int m = msb->asInt(); - int l = lsb->asInt(); - if (m < 0) { - typeError("%1%: negative bit index %2%", expression, msb); - return expression; - } - if (l < 0) { - typeError("%1%: negative bit index %2%", expression, lsb); - return expression; - } - if (m >= bst->size) { - typeError("Bit index %1% greater than width %2%", msb, bst->size); - return expression; - } - if (l >= bst->size) { - typeError("Bit index %1% greater than width %2%", msb, bst->size); - return expression; - } - if (l > m) { - typeError("LSB index %1% greater than MSB index %2%", lsb, msb); - return expression; - } - - const IR::Type *resultType = IR::Type_Bits::get(bst->srcInfo, m - l + 1, false); - resultType = canonicalize(resultType); - if (resultType == nullptr) return expression; - setType(getOriginal(), resultType); - setType(expression, resultType); - if (isLeftValue(expression->e0)) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } - if (isCompileTimeConstant(expression->e0)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Dots *expression) { - if (done()) return expression; - setType(expression, IR::Type_Any::get()); - setType(getOriginal(), IR::Type_Any::get()); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Mux *expression) { - if (done()) return expression; - const IR::Type *firstType = getType(expression->e0); - const IR::Type *secondType = getType(expression->e1); - const IR::Type *thirdType = getType(expression->e2); - if (firstType == nullptr || secondType == nullptr || thirdType == nullptr) return expression; - - if (!firstType->is()) { - typeError("Selector of %1% must be bool, not %2%", expression->getStringOp(), - firstType->toString()); - return expression; - } - - if (secondType->is() && thirdType->is()) { - typeError("Width must be specified for at least one of %1% or %2%", expression->e1, - expression->e2); - return expression; - } - auto tvs = unify(expression, secondType, thirdType, - "The expressions in a ?: conditional have different types '%1%' and '%2%'", - {secondType, thirdType}); - if (tvs != nullptr) { - if (!tvs->isIdentity()) { - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto e1 = cts.convert(expression->e1, getChildContext()); - auto e2 = cts.convert(expression->e2, getChildContext()); - if (::errorCount() > 0) return expression; - expression->e1 = e1; - expression->e2 = e2; - secondType = typeMap->getType(e1); - } - setType(expression, secondType); - setType(getOriginal(), secondType); - if (isCompileTimeConstant(expression->e0) && isCompileTimeConstant(expression->e1) && - isCompileTimeConstant(expression->e2)) { - auto result = constantFold(expression); - setCompileTimeConstant(result); - setCompileTimeConstant(getOriginal()); - return result; - } - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::TypeNameExpression *expression) { - if (done()) return expression; - auto type = getType(expression->typeName); - if (type == nullptr) return expression; - setType(getOriginal(), type); - setType(expression, type); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::Member *expression) { - if (done()) return expression; - auto type = getType(expression->expr); - if (type == nullptr) return expression; - - cstring member = expression->member.name; - if (auto ts = type->to()) type = ts->substituted; - - if (auto *ext = type->to()) { - auto call = findContext(); - if (call == nullptr) { - typeError("%1%: Methods can only be called", expression); - return expression; - } - auto method = ext->lookupMethod(expression->member, call->arguments); - if (method == nullptr) { - typeError("%1%: extern %2% does not have method matching this call", expression, - ext->name); - return expression; - } - - const IR::Type *methodType = getType(method); - if (methodType == nullptr) return expression; - // Each method invocation uses fresh type variables - methodType = cloneWithFreshTypeVariables(methodType->to()); - - setType(getOriginal(), methodType); - setType(expression, methodType); - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; - } - - bool inMethod = getParent() != nullptr; - // Built-in methods - if (inMethod && (member == IR::Type::minSizeInBits || member == IR::Type::minSizeInBytes || - member == IR::Type::maxSizeInBits || member == IR::Type::maxSizeInBytes)) { - auto type = new IR::Type_Method(IR::Type_InfInt::get(), new IR::ParameterList(), member); - auto ctype = canonicalize(type); - if (ctype == nullptr) return expression; - setType(getOriginal(), ctype); - setType(expression, ctype); - return expression; - } - - if (type->is()) { - std::string typeStr = "structure "; - if (type->is() || type->is()) { - typeStr = ""; - if (inMethod && (member == IR::Type_Header::isValid)) { - // Built-in method - auto type = - new IR::Type_Method(IR::Type_Boolean::get(), new IR::ParameterList(), member); - auto ctype = canonicalize(type); - if (ctype == nullptr) return expression; - setType(getOriginal(), ctype); - setType(expression, ctype); - return expression; - } - } - if (type->is()) { - if (inMethod && - (member == IR::Type_Header::setValid || member == IR::Type_Header::setInvalid)) { - if (!isLeftValue(expression->expr)) - typeError("%1%: must be applied to a left-value", expression); - // Built-in method - auto type = - new IR::Type_Method(IR::Type_Void::get(), new IR::ParameterList, member); - auto ctype = canonicalize(type); - if (ctype == nullptr) return expression; - setType(getOriginal(), ctype); - setType(expression, ctype); - return expression; - } - } - - auto stb = type->to(); - auto field = stb->getField(member); - if (field == nullptr) { - typeError("Field %1% is not a member of %2%%3%", expression->member, typeStr, stb); - return expression; - } - - auto fieldType = getTypeType(field->type); - if (fieldType == nullptr) return expression; - if (fieldType->is() && !getParent()) { - typeError("%1%: only allowed in switch statements", expression); - return expression; - } - setType(getOriginal(), fieldType); - setType(expression, fieldType); - if (isLeftValue(expression->expr)) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } else { - LOG2("No left value " << expression->expr); - } - if (isCompileTimeConstant(expression->expr)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - return expression; - } - - if (auto *apply = type->to(); apply && member == IR::IApply::applyMethodName) { - auto *methodType = apply->getApplyMethodType(); - auto *canon = canonicalize(methodType); - if (!canon) return expression; - methodType = canon->to(); - if (methodType == nullptr) return expression; - learn(methodType, this, getChildContext()); - setType(getOriginal(), methodType); - setType(expression, methodType); - return expression; - } - - if (auto *stack = type->to()) { - auto parser = findContext(); - if (member == IR::Type_Stack::next || member == IR::Type_Stack::last) { - if (parser == nullptr) { - typeError("%1%: 'last', and 'next' for stacks can only be used in a parser", - expression); - return expression; - } - setType(getOriginal(), stack->elementType); - setType(expression, stack->elementType); - if (isLeftValue(expression->expr) && member == IR::Type_Stack::next) { - setLeftValue(expression); - setLeftValue(getOriginal()); - } - return expression; - } else if (member == IR::Type_Stack::arraySize) { - setType(getOriginal(), IR::Type_Bits::get(32)); - setType(expression, IR::Type_Bits::get(32)); - return expression; - } else if (member == IR::Type_Stack::lastIndex) { - if (parser == nullptr) { - typeError("%1%: 'lastIndex' for stacks can only be used in a parser", expression); - return expression; - } - setType(getOriginal(), IR::Type_Bits::get(32, false)); - setType(expression, IR::Type_Bits::get(32, false)); - return expression; - } else if (member == IR::Type_Stack::push_front || member == IR::Type_Stack::pop_front) { - if (parser != nullptr) - typeError("%1%: '%2%' and '%3%' for stacks cannot be used in a parser", expression, - IR::Type_Stack::push_front, IR::Type_Stack::pop_front); - if (!isLeftValue(expression->expr)) - typeError("%1%: must be applied to a left-value", expression); - auto params = new IR::IndexedVector(); - auto param = new IR::Parameter(IR::ID("count"_cs, nullptr), IR::Direction::None, - IR::Type_InfInt::get()); - auto tt = new IR::Type_Type(param->type); - setType(param->type, tt); - setType(param, param->type); - params->push_back(param); - auto type = - new IR::Type_Method(IR::Type_Void::get(), new IR::ParameterList(*params), member); - auto canon = canonicalize(type); - if (canon == nullptr) return expression; - setType(getOriginal(), canon); - setType(expression, canon); - return expression; - } - } - - if (auto *tt = type->to()) { - auto base = tt->type; - if (base->is() || base->is() || - base->is()) { - if (isCompileTimeConstant(expression->expr)) { - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - auto fbase = base->to(); - if (auto decl = fbase->getDeclByName(member)) { - if (auto ftype = getType(decl->getNode())) { - setType(getOriginal(), ftype); - setType(expression, ftype); - } - } else { - typeError("%1%: Invalid enum tag", expression); - setType(getOriginal(), type); - setType(expression, type); - } - return expression; - } - } - - typeError("Cannot extract field %1% from %2% which has type %3%", expression->member, - expression->expr, type->toString()); - // unreachable - return expression; -} - -// If inActionList this call is made in the "action" property of a table -const IR::Expression *TypeInference::actionCall(bool inActionList, - const IR::MethodCallExpression *actionCall) { - // If a is an action with signature _(arg1, arg2, arg3) - // Then the call a(arg1, arg2) is also an - // action, with signature _(arg3) - LOG2("Processing action " << dbp(actionCall)); - - if (findContext()) { - typeError("%1%: Action calls are not allowed within parsers", actionCall); - return actionCall; - } - auto method = actionCall->method; - auto methodType = getType(method); - if (!methodType) return actionCall; // error emitted in getType - auto baseType = methodType->to(); - if (!baseType) { - typeError("%1%: must be an action", method); - return actionCall; - } - LOG2("Action type " << baseType); - BUG_CHECK(method->is(), "%1%: unexpected call", method); - BUG_CHECK(baseType->returnType == nullptr, "%1%: action with return type?", - baseType->returnType); - if (!baseType->typeParameters->empty()) { - typeError("%1%: Actions cannot be generic", baseType->typeParameters); - return actionCall; - } - if (!actionCall->typeArguments->empty()) { - typeError("%1%: Cannot supply type parameters for an action invocation", - actionCall->typeArguments); - return actionCall; - } - - bool inTable = findContext() != nullptr; - - TypeConstraints constraints(typeMap->getSubstitutions(), typeMap); - auto params = new IR::ParameterList; - - // keep track of parameters that have not been matched yet - absl::flat_hash_map left; - for (auto p : baseType->parameters->parameters) left.emplace(p->name, p); - - auto paramIt = baseType->parameters->parameters.begin(); - auto newArgs = new IR::Vector(); - bool changed = false; - for (auto arg : *actionCall->arguments) { - cstring argName = arg->name.name; - bool named = !argName.isNullOrEmpty(); - const IR::Parameter *param; - auto newExpr = arg->expression; - - if (named) { - param = baseType->parameters->getParameter(argName); - if (param == nullptr) { - typeError("%1%: No parameter named %2%", baseType->parameters, arg->name); - return actionCall; - } - } else { - if (paramIt == baseType->parameters->parameters.end()) { - typeError("%1%: Too many arguments for action", actionCall); - return actionCall; - } - param = *paramIt; - } - - LOG2("Action parameter " << dbp(param)); - if (!left.erase(param->name)) { - // This should have been checked by the CheckNamedArgs pass. - BUG("%1%: Duplicate argument name?", param->name); - } - - auto paramType = getType(param); - auto argType = getType(arg); - if (paramType == nullptr || argType == nullptr) - // type checking failed before - return actionCall; - constraints.addImplicitCastConstraint(actionCall, paramType, argType); - if (param->direction == IR::Direction::None) { - if (inActionList) { - typeError("%1%: parameter %2% cannot be bound: it is set by the control plane", arg, - param); - } else if (inTable) { - // For actions None parameters are treated as IN - // parameters when the action is called directly. We - // don't require them to be bound to a compile-time - // constant. But if the action is instantiated in a - // table (as default_action or entries), then the - // arguments do have to be compile-time constants. - if (!isCompileTimeConstant(arg->expression)) - typeError("%1%: action argument must be a compile-time constant", - arg->expression); - } - // This is like an assignment; may make additional conversions. - newExpr = assignment(arg, param->type, arg->expression); - if (readOnly) { - // FIXME -- if we're in readonly mode, we should not have introduced any mods - // here, but there's a bug in the DPDK backend where it generates a ListExpression - // that would be converted to a StructExpression, and other problems where it - // can't deal with that StructExpressions, so we hack to avoid breaking those tests - newExpr = arg->expression; - } - } else if (param->direction == IR::Direction::Out || - param->direction == IR::Direction::InOut) { - if (!isLeftValue(arg->expression)) - typeError("%1%: must be a left-value", arg->expression); - } else { - // This is like an assignment; may make additional conversions. - newExpr = assignment(arg, param->type, arg->expression); - } - if (::errorCount() > 0) return actionCall; - if (newExpr != arg->expression) { - LOG2("Changing action argument to " << newExpr); - changed = true; - newArgs->push_back(new IR::Argument(arg->srcInfo, arg->name, newExpr)); - } else { - newArgs->push_back(arg); - } - if (!named) ++paramIt; - } - if (changed) - actionCall = - new IR::MethodCallExpression(actionCall->srcInfo, actionCall->type, actionCall->method, - actionCall->typeArguments, newArgs); - - // Check remaining parameters: they must be all non-directional - bool error = false; - for (auto p : left) { - if (p.second->direction != IR::Direction::None && p.second->defaultValue == nullptr) { - typeError("%1%: Parameter %2% must be bound", actionCall, p.second); - error = true; - } - } - if (error) return actionCall; - - auto resultType = new IR::Type_Action(baseType->srcInfo, baseType->typeParameters, params); - - setType(getOriginal(), resultType); - setType(actionCall, resultType); - auto tvs = constraints.solve(); - if (tvs == nullptr || errorCount() > 0) return actionCall; - addSubstitutions(tvs); - - ConstantTypeSubstitution cts(tvs, typeMap, this); - actionCall = cts.convert(actionCall, getChildContext()) - ->to(); // cast arguments - if (::errorCount() > 0) return actionCall; - - LOG2("Converted action " << actionCall); - setType(actionCall, resultType); - return actionCall; -} - -bool hasVarbitsOrUnions(const TypeMap *typeMap, const IR::Type *type) { - // called for a canonical type - if (type->is() || type->is()) { - return true; - } else if (auto ht = type->to()) { - const IR::StructField *varbit = nullptr; - for (auto f : ht->fields) { - auto ftype = typeMap->getType(f); - if (ftype == nullptr) continue; - if (ftype->is()) { - if (varbit == nullptr) { - varbit = f; - } else { - typeError("%1% and %2%: multiple varbit fields in a header", varbit, f); - return type; - } - } - } - return varbit != nullptr; - } else if (auto at = type->to()) { - return hasVarbitsOrUnions(typeMap, at->elementType); - } else if (auto tpl = type->to()) { - for (auto f : tpl->components) { - if (hasVarbitsOrUnions(typeMap, f)) return true; - } - } - return false; -} - -bool TypeInference::onlyBitsOrBitStructs(const IR::Type *type) const { - // called for a canonical type - if (type->is() || type->is() || type->is()) { - return true; - } else if (auto ht = type->to()) { - for (auto f : ht->fields) { - auto ftype = typeMap->getType(f); - BUG_CHECK((ftype != nullptr), - "onlyBitsOrBitStructs check could not find type " - "for %1%", - f); - if (!onlyBitsOrBitStructs(ftype)) return false; - } - return true; - } - return false; -} - -const IR::Node *TypeInference::postorder(IR::MethodCallStatement *mcs) { - // Remove mcs if child methodCall resolves to a compile-time constant. - return !mcs->methodCall ? nullptr : mcs; -} - -const IR::Node *TypeInference::postorder(IR::MethodCallExpression *expression) { - if (done()) return expression; - LOG2("Solving method call " << dbp(expression)); - auto methodType = getType(expression->method); - if (methodType == nullptr) return expression; - auto methodBaseType = methodType->to(); - if (methodBaseType == nullptr) { - typeError("%1% is not a method", expression); - return expression; - } - - // Handle differently methods and actions: action invocations return actions - // with different signatures - if (methodType->is()) { - if (findContext()) { - typeError("%1%: Functions cannot call actions", expression); - return expression; - } - bool inActionsList = false; - auto prop = findContext(); - if (prop != nullptr && prop->name == IR::TableProperties::actionsPropertyName) - inActionsList = true; - return actionCall(inActionsList, expression); - } else { - // Constant-fold constant expressions - if (auto mem = expression->method->to()) { - auto type = typeMap->getType(mem->expr, true); - if (((mem->member == IR::Type::minSizeInBits || - mem->member == IR::Type::minSizeInBytes || - mem->member == IR::Type::maxSizeInBits || - mem->member == IR::Type::maxSizeInBytes)) && - !type->is() && expression->typeArguments->size() == 0 && - expression->arguments->size() == 0) { - auto max = mem->member.name.startsWith("max"); - int w = typeMap->widthBits(type, expression, max); - LOG3("Folding " << mem << " to " << w); - if (w < 0) return expression; - if (mem->member.name.endsWith("Bytes")) w = ROUNDUP(w, 8); - if (getParent()) return nullptr; - auto result = new IR::Constant(expression->srcInfo, w); - auto tt = new IR::Type_Type(result->type); - setType(result->type, tt); - setType(result, result->type); - setCompileTimeConstant(result); - return result; - } - if (mem->member == IR::Type_Header::isValid && type->is()) { - const IR::BoolLiteral *lit = nullptr; - if (mem->expr->is()) - lit = new IR::BoolLiteral(expression->srcInfo, false); - if (mem->expr->is()) - lit = new IR::BoolLiteral(expression->srcInfo, false); - if (mem->expr->is()) - lit = new IR::BoolLiteral(expression->srcInfo, true); - if (lit) { - LOG3("Folding " << mem << " to " << lit); - if (getParent()) return nullptr; - setType(lit, IR::Type_Boolean::get()); - setCompileTimeConstant(lit); - return lit; - } - } - } - - if (getContext()->node->is()) { - typeError("%1% is not invoking an action", expression); - return expression; - } - - // We build a type for the callExpression and unify it with the method expression - // Allocate a fresh variable for the return type; it will be hopefully bound in the process. - auto rettype = new IR::Type_Var(IR::ID(nameGen->newName("R"), ""_cs)); - auto args = new IR::Vector(); - bool constArgs = true; - for (auto aarg : *expression->arguments) { - auto arg = aarg->expression; - auto argType = getType(arg); - if (argType == nullptr) return expression; - auto argInfo = new IR::ArgumentInfo(arg->srcInfo, isLeftValue(arg), - isCompileTimeConstant(arg), argType, aarg); - args->push_back(argInfo); - constArgs = constArgs && isCompileTimeConstant(arg); - } - auto typeArgs = new IR::Vector(); - for (auto ta : *expression->typeArguments) { - auto taType = getTypeType(ta); - if (taType == nullptr) return expression; - typeArgs->push_back(taType); - } - auto callType = new IR::Type_MethodCall(expression->srcInfo, typeArgs, rettype, args); - - auto tvs = unify(expression, methodBaseType, callType, - "Function type '%1%' does not match invocation type '%2%'", - {methodBaseType, callType}); - if (tvs == nullptr) return expression; - - // Infer Dont_Care for type vars used only in not-present optional params - auto dontCares = new TypeVariableSubstitution(); - auto typeParams = methodBaseType->typeParameters; - for (auto p : *methodBaseType->parameters) { - if (!p->isOptional()) continue; - forAllMatching( - p, [tvs, dontCares, typeParams, this](const IR::Type_Var *tv) { - if (typeMap->getSubstitutions()->lookup(tv) != nullptr) - return; // already bound - if (tvs->lookup(tv)) return; // already bound - if (typeParams->getDeclByName(tv->name) != tv) return; // not a tv of this call - dontCares->setBinding(tv, IR::Type_Dontcare::get()); - }); - } - addSubstitutions(dontCares); - - LOG2("Method type before specialization " << methodType << " with " << tvs); - TypeVariableSubstitutionVisitor substVisitor(tvs); - substVisitor.setCalledBy(this); - auto specMethodType = methodType->apply(substVisitor); - LOG2("Method type after specialization " << specMethodType); - learn(specMethodType, this, getChildContext()); - - auto canon = getType(specMethodType); - if (canon == nullptr) return expression; - - auto functionType = specMethodType->to(); - BUG_CHECK(functionType != nullptr, "Method type is %1%", specMethodType); - - if (!functionType->is()) - BUG("Unexpected type for function %1%", functionType); - - auto returnType = tvs->lookup(rettype); - if (returnType == nullptr) { - typeError("Cannot infer a concrete return type for this call of %1%", expression); - return expression; - } - // The return type may also contain type variables - returnType = returnType->apply(substVisitor)->to(); - learn(returnType, this, getChildContext()); - if (returnType->is() || returnType->is() || - returnType->is() || returnType->is() || - returnType->is() || - (returnType->is() && !constArgs)) { - // Experimental: methods with all constant arguments can return an extern - // instance as a factory method evaluated at compile time. - typeError("%1%: illegal return type %2%", expression, returnType); - return expression; - } - - setType(getOriginal(), returnType); - setType(expression, returnType); - - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto result = expression; - // Arguments may need to be cast, e.g., list expression to a - // header type. - auto paramIt = functionType->parameters->begin(); - auto newArgs = new IR::Vector(); - bool changed = false; - for (auto arg : *expression->arguments) { - cstring argName = arg->name.name; - bool named = !argName.isNullOrEmpty(); - const IR::Parameter *param; - - if (named) { - param = functionType->parameters->getParameter(argName); - } else { - param = *paramIt; - } - - if (param->type->is()) - typeError( - "%1%: Could not infer a type for parameter %2% " - "(inferred type is don't care '_')", - arg, param); - - // By calling generic functions with don't care parameters - // we can force parameters to have illegal types. Check here for this case. - // e.g., void f(in T arg); table t { }; f(t); - if (param->type->is() || param->type->is() || - param->type->is() || param->type->is() || - param->type->is()) - typeError("%1%: argument cannot have type %2%", arg, param->type); - - auto newExpr = arg->expression; - if (param->direction == IR::Direction::In) { - // This is like an assignment; may make additional conversions. - newExpr = assignment(arg, param->type, arg->expression); - } else { - // Insert casts for 'int' values. - newExpr = cts.convert(newExpr, getChildContext())->to(); - } - if (::errorCount() > 0) return expression; - if (newExpr != arg->expression) { - LOG2("Changing method argument to " << newExpr); - changed = true; - newArgs->push_back(new IR::Argument(arg->srcInfo, arg->name, newExpr)); - } else { - newArgs->push_back(arg); - } - if (!named) ++paramIt; - } - - if (changed) - result = new IR::MethodCallExpression(result->srcInfo, result->type, result->method, - result->typeArguments, newArgs); - setType(result, returnType); - - auto mi = MethodInstance::resolve(result, this, typeMap, getChildContext(), true); - if (mi->isApply() && findContext()) { - typeError("%1%: apply cannot be called from actions", expression); - return expression; - } - - if (const auto *ef = mi->to()) { - const IR::Type *baseReturnType = returnType; - if (const auto *sc = returnType->to()) - baseReturnType = sc->baseType; - const bool factoryOrStaticAssert = - baseReturnType->is() || ef->method->name == "static_assert"; - if (constArgs && factoryOrStaticAssert) { - // factory extern function calls (those that return extern objects) with constant - // args are compile-time constants. - // The result of a static_assert call is also a compile-time constant. - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - } - } - - auto bi = mi->to(); - if ((findContext()) && (!bi || (bi->name == IR::Type_Stack::pop_front || - bi->name == IR::Type_Stack::push_front))) { - typeError("%1%: no function calls allowed in this context", expression); - return expression; - } - return result; - } - return expression; -} - -const IR::Node *TypeInference::postorder(IR::ConstructorCallExpression *expression) { - if (done()) return expression; - auto type = getTypeType(expression->constructedType); - if (type == nullptr) return expression; - - auto simpleType = type; - CHECK_NULL(simpleType); - if (auto *sc = type->to()) simpleType = sc->substituted; - - if (auto *e = simpleType->to()) { - auto [contType, newArgs] = checkExternConstructor(expression, e, expression->arguments); - if (newArgs == nullptr) return expression; - expression->arguments = newArgs; - setType(getOriginal(), contType); - setType(expression, contType); - } else if (auto *c = simpleType->to()) { - auto typeAndArgs = containerInstantiation(expression, expression->arguments, c); - auto contType = typeAndArgs.first; - auto args = typeAndArgs.second; - if (contType == nullptr || args == nullptr) return expression; - if (auto *st = type->to()) { - contType = new IR::Type_SpecializedCanonical(type->srcInfo, st->baseType, st->arguments, - contType); - } - expression->arguments = args; - setType(expression, contType); - setType(getOriginal(), contType); - } else { - typeError("%1%: Cannot invoke a constructor on type %2%", expression, type->toString()); - } - - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -static void convertStructToTuple(const IR::Type_StructLike *structType, IR::Type_Tuple *tuple) { - for (auto field : structType->fields) { - if (auto ft = field->type->to()) { - tuple->components.push_back(ft); - } else if (auto ft = field->type->to()) { - convertStructToTuple(ft, tuple); - } else if (auto ft = field->type->to()) { - tuple->components.push_back(ft); - } else if (auto ft = field->type->to()) { - tuple->components.push_back(ft); - } else { - typeError("Type not supported %1% for struct field %2% in 'select'", field->type, - field); - } - } -} - -const IR::SelectCase *TypeInference::matchCase(const IR::SelectExpression *select, - const IR::Type_BaseList *selectType, - const IR::SelectCase *selectCase, - const IR::Type *caseType) { - // The selectType is always a tuple - // If the caseType is a set type, we unify the type of the set elements - if (auto *set = caseType->to()) caseType = set->elementType; - // The caseType may be a simple type, and then we have to unwrap the selectType - if (caseType->is()) return selectCase; - - if (auto *sl = caseType->to()) { - auto tupleType = new IR::Type_Tuple(); - convertStructToTuple(sl, tupleType); - caseType = tupleType; - } - const IR::Type *useSelType = selectType; - if (!caseType->is()) { - if (selectType->components.size() != 1) { - typeError("Type mismatch %1% (%2%) vs %3% (%4%)", select->select, - selectType->toString(), selectCase, caseType->toString()); - return nullptr; - } - useSelType = selectType->components.at(0); - } - auto tvs = unifyCast( - select, useSelType, caseType, - "'match' case label '%1%' has type '%2%' which does not match the expected type '%3%'", - {selectCase->keyset, caseType, useSelType}); - if (tvs == nullptr) return nullptr; - ConstantTypeSubstitution cts(tvs, typeMap, this); - auto ks = cts.convert(selectCase->keyset, getChildContext()); - if (::errorCount() > 0) return selectCase; - - if (ks != selectCase->keyset) - selectCase = new IR::SelectCase(selectCase->srcInfo, ks, selectCase->state); - return selectCase; -} - -const IR::Node *TypeInference::postorder(IR::This *expression) { - if (done()) return expression; - auto decl = findContext(); - if (findContext() == nullptr || decl == nullptr) - typeError("%1%: can only be used in the definition of an abstract method", expression); - auto type = getType(decl); - setType(expression, type); - setType(getOriginal(), type); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::DefaultExpression *expression) { - if (!done()) { - setType(expression, IR::Type_Dontcare::get()); - setType(getOriginal(), IR::Type_Dontcare::get()); - } - setCompileTimeConstant(expression); - setCompileTimeConstant(getOriginal()); - return expression; -} - -bool TypeInference::containsHeader(const IR::Type *type) { - if (type->is() || type->is() || - type->is()) - return true; - if (auto *st = type->to()) { - for (auto f : st->fields) - if (containsHeader(f->type)) return true; - } - return false; -} - -/// Expressions that appear in a select expression are restricted to a small -/// number of types: bits, enums, serializable enums, and booleans. -static bool validateSelectTypes(const IR::Type *type, const IR::SelectExpression *expression) { - if (auto tuple = type->to()) { - for (auto ct : tuple->components) { - auto check = validateSelectTypes(ct, expression); - if (!check) return false; - } - return true; - } else if (type->is() || type->is() || - type->is() || type->is()) { - return true; - } - typeError("Expression '%1%' with a component of type '%2%' cannot be used in 'select'", - expression->select, type); - return false; -} - -const IR::Node *TypeInference::postorder(IR::SelectExpression *expression) { - if (done()) return expression; - auto selectType = getType(expression->select); - if (selectType == nullptr) return expression; - - // Check that the selectType is determined - auto tuple = selectType->to(); - BUG_CHECK(tuple != nullptr, "%1%: Expected a tuple type for the select expression, got %2%", - expression, selectType); - if (!validateSelectTypes(selectType, expression)) return expression; - - bool changes = false; - IR::Vector vec; - for (auto sc : expression->selectCases) { - auto type = getType(sc->keyset); - if (type == nullptr) return expression; - auto newsc = matchCase(expression, tuple, sc, type); - vec.push_back(newsc); - if (newsc != sc) changes = true; - } - if (changes) - expression = - new IR::SelectExpression(expression->srcInfo, expression->select, std::move(vec)); - setType(expression, IR::Type_State::get()); - setType(getOriginal(), IR::Type_State::get()); - return expression; -} - -const IR::Node *TypeInference::postorder(IR::AttribLocal *local) { - setType(local, local->type); - setType(getOriginal(), local->type); - return local; -} - -///////////////////////////////////////// Statements et al. - -const IR::Node *TypeInference::postorder(IR::IfStatement *conditional) { - LOG3("TI Visiting " << dbp(getOriginal())); - auto type = getType(conditional->condition); - if (type == nullptr) return conditional; - if (!type->is()) - typeError("Condition of %1% does not evaluate to a bool but %2%", conditional, - type->toString()); - return conditional; -} - -const IR::Node *TypeInference::postorder(IR::SwitchStatement *stat) { - LOG3("TI Visiting " << dbp(getOriginal())); - auto type = getType(stat->expression); - if (type == nullptr) return stat; - - if (auto ae = type->to()) { - // switch (table.apply(...)) - absl::flat_hash_map foundLabels; - const IR::Node *foundDefault = nullptr; - for (auto c : stat->cases) { - if (c->label->is()) { - if (foundDefault) - typeError("%1%: multiple 'default' labels %2%", c->label, foundDefault); - foundDefault = c->label; - continue; - } else if (auto pe = c->label->to()) { - cstring label = pe->path->name.name; - auto [it, inserted] = foundLabels.emplace(label, c->label); - if (!inserted) - typeError("%1%: 'switch' label duplicates %2%", c->label, it->second); - if (!ae->contains(label)) - typeError("%1% is not a legal label (action name)", c->label); - } else { - typeError("%1%: 'switch' label must be an action name or 'default'", c->label); - } - } - } else { - // switch (expression) - Comparison comp; - comp.left = stat->expression; - if (isCompileTimeConstant(stat->expression)) - warn(ErrorType::WARN_MISMATCH, "%1%: constant expression in switch", stat->expression); - - for (auto &c : stat->cases) { - if (!isCompileTimeConstant(c->label)) - typeError("%1%: must be a compile-time constant", c->label); - auto lt = getType(c->label); - if (lt == nullptr) continue; - if (lt->is() && type->is()) { - c = new IR::SwitchCase(c->srcInfo, new IR::Cast(c->label->srcInfo, type, c->label), - c->statement); - setType(c->label, type); - setCompileTimeConstant(c->label); - continue; - } else if (c->label->is()) { - continue; - } - comp.right = c->label; - bool b = compare(stat, type, lt, &comp); - if (b && comp.right != c->label) { - c = new IR::SwitchCase(c->srcInfo, comp.right, c->statement); - setCompileTimeConstant(c->label); - } - } - } - return stat; -} - -const IR::Node *TypeInference::postorder(IR::ReturnStatement *statement) { - LOG3("TI Visiting " << dbp(getOriginal())); - auto func = findOrigCtxt(); - if (func == nullptr) { - if (statement->expression != nullptr) - typeError("%1%: return with expression can only be used in a function", statement); - return statement; - } - - auto ftype = getType(func); - if (ftype == nullptr) return statement; - - BUG_CHECK(ftype->is(), "%1%: expected a method type for function", ftype); - auto mt = ftype->to(); - auto returnType = mt->returnType; - CHECK_NULL(returnType); - if (returnType->is()) { - if (statement->expression != nullptr) - typeError("%1%: return expression in function with void return", statement); - return statement; - } - - if (statement->expression == nullptr) { - typeError("%1%: return with no expression in a function returning %2%", statement, - returnType->toString()); - return statement; - } - - auto init = assignment(statement, returnType, statement->expression); - if (init != statement->expression) statement->expression = init; - return statement; -} - -const IR::Node *TypeInference::postorder(IR::AssignmentStatement *assign) { - LOG3("TI Visiting " << dbp(getOriginal())); - auto ltype = getType(assign->left); - if (ltype == nullptr) return assign; - - if (!isLeftValue(assign->left)) { - typeError("Expression %1% cannot be the target of an assignment", assign->left); - LOG2(assign->left); - return assign; - } - - auto newInit = assignment(assign, ltype, assign->right); - if (newInit != assign->right) - assign = new IR::AssignmentStatement(assign->srcInfo, assign->left, newInit); - return assign; -} - -const IR::Node *TypeInference::postorder(IR::ForInStatement *forin) { - LOG3("TI Visiting " << dbp(getOriginal())); - auto ltype = getType(forin->ref); - if (ltype == nullptr) return forin; - auto ctype = getType(forin->collection); - if (ctype == nullptr) return forin; - - if (!isLeftValue(forin->ref)) { - typeError("Expression %1% cannot be the target of an assignment", forin->ref); - LOG2(forin->ref); - return forin; - } - if (auto range = forin->collection->to()) { - auto rclone = range->clone(); - rclone->left = assignment(forin, ltype, rclone->left); - rclone->right = assignment(forin, ltype, rclone->right); - if (*range != *rclone) - forin->collection = rclone; - else - delete rclone; - } else if (auto *stack = ctype->to()) { - if (!canCastBetween(stack->elementType, ltype)) - typeError("%1% does not match header stack type %2%", forin->ref, ctype); - } else if (auto *list = ctype->to()) { - if (!canCastBetween(list->elementType, ltype)) - typeError("%1% does not match %2% element type", forin->ref, ctype); - } else { - error(ErrorType::ERR_UNSUPPORTED, - "%1%Typechecking does not support iteration over this collection of type %2%", - forin->collection->srcInfo, ctype); - } - return forin; -} - -const IR::Node *TypeInference::postorder(IR::ActionListElement *elem) { - if (done()) return elem; - auto type = getType(elem->expression); - if (type == nullptr) return elem; - - setType(elem, type); - setType(getOriginal(), type); - return elem; -} - -const IR::Node *TypeInference::postorder(IR::SelectCase *sc) { - auto type = getType(sc->state); - if (type != nullptr && type != IR::Type_State::get()) typeError("%1% must be state", sc); - return sc; -} - -const IR::Node *TypeInference::postorder(IR::KeyElement *elem) { - auto ktype = getType(elem->expression); - if (ktype == nullptr) return elem; - while (ktype->is()) ktype = getTypeType(ktype->to()->type); - if (!ktype->is() && !ktype->is() && - !ktype->is() && !ktype->is() && - !ktype->is()) - typeError("Key %1% field type must be a scalar type; it cannot be %2%", elem->expression, - ktype->toString()); - auto type = getType(elem->matchType); - if (type != nullptr && type != IR::Type_MatchKind::get()) - typeError("%1% must be a %2% value", elem->matchType, - IR::Type_MatchKind::get()->toString()); - if (isCompileTimeConstant(elem->expression) && !readOnly) - warn(ErrorType::WARN_IGNORE_PROPERTY, "%1%: constant key element", elem); - return elem; -} - -const IR::Node *TypeInference::postorder(IR::ActionList *al) { - LOG3("TI Visited " << dbp(al)); - BUG_CHECK(currentActionList == nullptr, "%1%: nested action list?", al); - currentActionList = al; - return al; -} - -const IR::ActionListElement *TypeInference::validateActionInitializer( - const IR::Expression *actionCall) { - // We cannot retrieve the action list from the table, because the - // table has not been modified yet. We want the latest version of - // the action list, as it has been already typechecked. - auto al = currentActionList; - if (al == nullptr) { - auto table = findContext(); - BUG_CHECK(table, "%1%: not within a table", actionCall); - typeError("%1% has no action list, so it cannot invoke '%2%'", table, actionCall); - return nullptr; - } - - auto call = actionCall->to(); - if (call == nullptr) { - typeError("%1%: expected an action call", actionCall); - return nullptr; - } - auto method = call->method; - if (!method->is()) BUG("%1%: unexpected expression", method); - auto pe = method->to(); - auto decl = getDeclaration(pe->path, !errorOnNullDecls); - if (errorOnNullDecls && decl == nullptr) { - typeError("%1%: Cannot resolve declaration", pe); - return nullptr; - } - - auto ale = al->actionList.getDeclaration(decl->getName()); - if (ale == nullptr) { - typeError("%1% not present in action list", call); - return nullptr; - } - - BUG_CHECK(ale->is(), "%1%: expected an ActionListElement", ale); - auto elem = ale->to(); - auto entrypath = elem->getPath(); - auto entrydecl = getDeclaration(entrypath, true); - if (entrydecl != decl) { - typeError("%1% and %2% refer to different actions", actionCall, elem); - return nullptr; - } - - // Check that the data-plane parameters - // match the data-plane parameters for the same action in - // the actions list. - auto actionListCall = elem->expression->to(); - CHECK_NULL(actionListCall); - auto type = typeMap->getType(actionListCall->method); - if (type == nullptr) { - typeError("%1%: action invocation should be after the `actions` list", actionCall); - return nullptr; - } - - if (actionListCall->arguments->size() > call->arguments->size()) { - typeError("%1%: not enough arguments", call); - return nullptr; - } - - SameExpression se(this, typeMap); - auto callInstance = MethodInstance::resolve(call, this, typeMap, getChildContext(), true); - auto listInstance = - MethodInstance::resolve(actionListCall, this, typeMap, getChildContext(), true); - - for (auto param : *listInstance->substitution.getParametersInArgumentOrder()) { - auto aa = listInstance->substitution.lookup(param); - auto da = callInstance->substitution.lookup(param); - if (da == nullptr) { - typeError("%1%: parameter should be assigned in call %2%", param, call); - return nullptr; - } - bool same = se.sameExpression(aa->expression, da->expression); - if (!same) { - typeError("%1%: argument does not match declaration in actions list: %2%", da, aa); - return nullptr; - } - } - - for (auto param : *callInstance->substitution.getParametersInOrder()) { - auto da = callInstance->substitution.lookup(param); - if (da == nullptr) { - typeError("%1%: parameter should be assigned in call %2%", param, call); - return nullptr; - } - } - - return elem; -} - -const IR::Node *TypeInference::postorder(IR::Property *prop) { - // Handle the default_action - if (prop->name == IR::TableProperties::defaultActionPropertyName) { - auto pv = prop->value->to(); - if (pv == nullptr) { - typeError("%1% table property should be an action", prop); - } else { - auto type = getType(pv->expression); - if (type == nullptr) return prop; - if (!type->is()) { - typeError("%1% table property should be an action", prop); - return prop; - } - auto at = type->to(); - if (at->parameters->size() != 0) { - typeError("%1%: parameter %2% does not have a corresponding argument", prop->value, - at->parameters->parameters.at(0)); - return prop; - } - - // Check that the default action appears in the list of actions. - BUG_CHECK(prop->value->is(), "%1% not an expression", prop); - auto def = prop->value->to()->expression; - auto ale = validateActionInitializer(def); - if (ale != nullptr) { - auto anno = ale->getAnnotation(IR::Annotation::tableOnlyAnnotation); - if (anno != nullptr) { - typeError("%1%: Action marked with %2% used as default action", prop, - IR::Annotation::tableOnlyAnnotation); - return prop; - } - } - } - } - return prop; -} - } // namespace P4 diff --git a/frontends/p4/typeChecking/typeChecker.h b/frontends/p4/typeChecking/typeChecker.h index 5665dbc26cb..ba533670753 100644 --- a/frontends/p4/typeChecking/typeChecker.h +++ b/frontends/p4/typeChecking/typeChecker.h @@ -19,7 +19,6 @@ limitations under the License. #include "frontends/common/resolveReferences/referenceMap.h" #include "frontends/common/resolveReferences/resolveReferences.h" -#include "frontends/p4/typeChecking/typeSubstitution.h" #include "frontends/p4/typeMap.h" #include "ir/ir.h" #include "ir/pass_manager.h" @@ -152,9 +151,8 @@ class TypeInference : public Transform, public ResolutionContext { * Made virtual to enable private midend passes to extend standard IR with custom IR classes. */ virtual const IR::Type *canonicalize(const IR::Type *type); - const IR::Type *canonicalizeFields( - const IR::Type_StructLike *type, - std::function *)> constructor); + template + const IR::Type *canonicalizeFields(const IR::Type_StructLike *type, Ctor constructor); virtual const IR::ParameterList *canonicalizeParameters(const IR::ParameterList *params); // various helpers diff --git a/frontends/p4/typeChecking/typeConstraints.h b/frontends/p4/typeChecking/typeConstraints.h index 74319040f05..36e17e7bc48 100644 --- a/frontends/p4/typeChecking/typeConstraints.h +++ b/frontends/p4/typeChecking/typeConstraints.h @@ -28,7 +28,7 @@ namespace P4 { /// Creates a string that describes the values of current type variables class Explain : public Inspector { - std::set explained; + absl::flat_hash_set explained; const TypeVariableSubstitution *subst; public: @@ -39,10 +39,10 @@ class Explain : public Inspector { return Inspector::init_apply(node); } void postorder(const IR::Type_Var *tv) override { - if (explained.find(tv) != explained.end()) + auto [_, inserted] = explained.emplace(tv); + if (!inserted) // Do not repeat explanations. return; - explained.emplace(tv); auto val = subst->lookup(tv); if (!val) return; explanation += "Where '" + tv->toString() + "' is bound to '" + val->toString() + "'\n"; @@ -191,7 +191,7 @@ class TypeConstraints final : public IHasDbPrint { * This example should not typecheck: because T cannot be constrained in the invocation of f. * While typechecking the f(data) call, T is not a type variable that can be unified. */ - std::set unifiableTypeVariables; + absl::flat_hash_set unifiableTypeVariables; std::vector constraints; TypeUnification *unification; const TypeVariableSubstitution *definedVariables;