diff --git a/svf-llvm/include/SVF-LLVM/ObjTypeInference.h b/svf-llvm/include/SVF-LLVM/ObjTypeInference.h index 51bc2dcd3..951be7d86 100644 --- a/svf-llvm/include/SVF-LLVM/ObjTypeInference.h +++ b/svf-llvm/include/SVF-LLVM/ObjTypeInference.h @@ -35,8 +35,10 @@ #include "SVFIR/SVFValue.h" #include "Util/ThreadAPI.h" -namespace SVF { -class ObjTypeInference { +namespace SVF +{ +class ObjTypeInference +{ public: typedef Set ValueSet; @@ -77,12 +79,14 @@ class ObjTypeInference { const Type *defaultType(const Value *val); /// pointer type - inline const Type *ptrType() { + inline const Type *ptrType() + { return PointerType::getUnqual(getLLVMCtx()); } /// int8 type - inline const IntegerType *int8Type() { + inline const IntegerType *int8Type() + { return Type::getInt8Ty(getLLVMCtx()); } diff --git a/svf-llvm/lib/CHGBuilder.cpp b/svf-llvm/lib/CHGBuilder.cpp index 0a1ed7ffc..588492cb4 100644 --- a/svf-llvm/lib/CHGBuilder.cpp +++ b/svf-llvm/lib/CHGBuilder.cpp @@ -687,15 +687,18 @@ const CHGraph::CHNodeSetTy& CHGBuilder::getCSClasses(const CallBase* cs) { Set thisPtrClassNames = getClassNameOfThisPtr(cs); - if(thisPtrClassNames.empty()) { + if(thisPtrClassNames.empty()) + { // if we cannot infer classname, conservatively push all class nodes - for (const auto &node: *chg) { + for (const auto &node: *chg) + { chg->csToClassesMap[svfcall].insert(node.second); } return chg->csToClassesMap[svfcall]; } - for (const auto &thisPtrClassName: thisPtrClassNames) { + for (const auto &thisPtrClassName: thisPtrClassNames) + { if (const CHNode* thisNode = chg->getNode(thisPtrClassName)) { const CHGraph::CHNodeSetTy& instAndDesces = getInstancesAndDescendants(thisPtrClassName); diff --git a/svf-llvm/lib/CppUtil.cpp b/svf-llvm/lib/CppUtil.cpp index 46b7da9b2..5cb2a4752 100644 --- a/svf-llvm/lib/CppUtil.cpp +++ b/svf-llvm/lib/CppUtil.cpp @@ -398,7 +398,8 @@ const Argument* cppUtil::getConstructorThisPtr(const Function* fun) return thisPtr; } -void updateClassNameBeforeBrackets(cppUtil::DemangledName& dname) { +void updateClassNameBeforeBrackets(cppUtil::DemangledName& dname) +{ dname.funcName = cppUtil::getBeforeBrackets(dname.funcName); dname.className = cppUtil::getBeforeBrackets(dname.className); size_t colon = dname.className.rfind("::"); @@ -409,7 +410,7 @@ void updateClassNameBeforeBrackets(cppUtil::DemangledName& dname) { else { dname.className = - cppUtil::getBeforeBrackets(dname.className.substr(colon + 2)); + cppUtil::getBeforeBrackets(dname.className.substr(colon + 2)); } } @@ -481,7 +482,8 @@ bool cppUtil::VCallInCtorOrDtor(const CallBase* cs) { Set classNameOfThisPtrs = cppUtil::getClassNameOfThisPtr(cs); const Function* func = cs->getCaller(); - for (const auto &classNameOfThisPtr: classNameOfThisPtrs) { + for (const auto &classNameOfThisPtr: classNameOfThisPtrs) + { if (cppUtil::isConstructor(func) || cppUtil::isDestructor(func)) { cppUtil::DemangledName dname = cppUtil::demangle(func->getName().str()); @@ -542,16 +544,19 @@ Set cppUtil::getClassNameOfThisPtr(const CallBase* inst) Set ans; std::transform(thisPtrNames.begin(), thisPtrNames.end(), std::inserter(ans, ans.begin()), - [](const std::string &thisPtrName) -> std::string { - size_t found = thisPtrName.find_last_not_of("0123456789"); - if (found != std::string::npos) { - if (found != thisPtrName.size() - 1 && - thisPtrName[found] == '.') { - return thisPtrName.substr(0, found); - } - } - return thisPtrName; - }); + [](const std::string &thisPtrName) -> std::string + { + size_t found = thisPtrName.find_last_not_of("0123456789"); + if (found != std::string::npos) + { + if (found != thisPtrName.size() - 1 && + thisPtrName[found] == '.') + { + return thisPtrName.substr(0, found); + } + } + return thisPtrName; + }); return ans; } @@ -626,15 +631,19 @@ bool LLVMUtil::isConstantObjSym(const Value* val) * @param foo * @return */ -Set cppUtil::extractClsNamesFromFunc(const Function *foo) { +Set cppUtil::extractClsNamesFromFunc(const Function *foo) +{ assert(foo->hasName() && "foo does not have a name? possible indirect call"); const std::string &name = foo->getName().str(); - if (isConstructor(foo)) { + if (isConstructor(foo)) + { // c++ constructor DemangledName demangledName = cppUtil::demangle(name); updateClassNameBeforeBrackets(demangledName); return {demangledName.className}; - } else if (isTemplateFunc(foo)) { + } + else if (isTemplateFunc(foo)) + { // array index Set classNames = extractClsNamesFromTemplate(name); assert(!classNames.empty() && "empty class names?"); @@ -649,29 +658,38 @@ Set cppUtil::extractClsNamesFromFunc(const Function *foo) { * @param input * @return */ -std::vector findInnermostBrackets(const std::string &input) { +std::vector findInnermostBrackets(const std::string &input) +{ typedef std::pair StEdIdxPair; std::stack stack; std::vector innerMostPairs; std::vector used(input.length(), false); - for (u32_t i = 0; i < input.length(); ++i) { - if (input[i] == '<') { + for (u32_t i = 0; i < input.length(); ++i) + { + if (input[i] == '<') + { stack.push(i); - } else if (input[i] == '>' && i > 0 && input[i - 1] != '-') { - if (!stack.empty()) { + } + else if (input[i] == '>' && i > 0 && input[i - 1] != '-') + { + if (!stack.empty()) + { int openIndex = stack.top(); stack.pop(); // Check if this pair is innermost bool isInnermost = true; - for (u32_t j = openIndex + 1; j < i && isInnermost; ++j) { - if (used[j]) { + for (u32_t j = openIndex + 1; j < i && isInnermost; ++j) + { + if (used[j]) + { isInnermost = false; } } - if (isInnermost) { + if (isInnermost) + { innerMostPairs.emplace_back(openIndex, i); used[openIndex] = used[i] = true; // Mark these indices as used } @@ -679,7 +697,8 @@ std::vector findInnermostBrackets(const std::string &input) { } } std::vector ans(innerMostPairs.size()); - std::transform(innerMostPairs.begin(), innerMostPairs.end(), ans.begin(), [&input](StEdIdxPair &p) -> std::string { + std::transform(innerMostPairs.begin(), innerMostPairs.end(), ans.begin(), [&input](StEdIdxPair &p) -> std::string + { return input.substr(p.first + 1, p.second - p.first - 1); }); return ans; @@ -690,22 +709,27 @@ std::vector findInnermostBrackets(const std::string &input) { * @param str * @return */ -std::string stripWhitespaces(const std::string &str) { - auto start = std::find_if(str.begin(), str.end(), [](unsigned char ch) { +std::string stripWhitespaces(const std::string &str) +{ + auto start = std::find_if(str.begin(), str.end(), [](unsigned char ch) + { return !std::isspace(ch); }); - auto end = std::find_if(str.rbegin(), str.rend(), [](unsigned char ch) { + auto end = std::find_if(str.rbegin(), str.rend(), [](unsigned char ch) + { return !std::isspace(ch); }).base(); return (start < end) ? std::string(start, end) : std::string(); } -std::vector splitAndStrip(const std::string &input, char delimiter) { +std::vector splitAndStrip(const std::string &input, char delimiter) +{ std::vector tokens; size_t start = 0, end = 0; - while ((end = input.find(delimiter, start)) != std::string::npos) { + while ((end = input.find(delimiter, start)) != std::string::npos) + { tokens.push_back(stripWhitespaces(input.substr(start, end - start))); start = end + 1; } @@ -720,21 +744,27 @@ std::vector splitAndStrip(const std::string &input, char delimiter) * @param oname * @return */ -Set cppUtil::extractClsNamesFromTemplate(const std::string &oname) { +Set cppUtil::extractClsNamesFromTemplate(const std::string &oname) +{ // "std::array" -> A // "std::queue > >" -> A // __gnu_cxx::__aligned_membuf >::_M_ptr() const -> A Set ans; std::string demangleName = llvm::demangle(oname); std::vector innermosts = findInnermostBrackets(demangleName); - for (const auto &innermost: innermosts) { + for (const auto &innermost: innermosts) + { const std::vector &allstrs = splitAndStrip(innermost, ','); - for (const auto &str: allstrs) { + for (const auto &str: allstrs) + { size_t spacePos = str.find(' '); - if (spacePos != std::string::npos) { + if (spacePos != std::string::npos) + { // A const* -> A ans.insert(str.substr(0, spacePos)); - } else { + } + else + { size_t starPos = str.find('*'); if (starPos != std::string::npos) // A* -> A @@ -754,8 +784,10 @@ Set cppUtil::extractClsNamesFromTemplate(const std::string &oname) * @param val * @return */ -bool cppUtil::isClsNameSource(const Value *val) { - if (const auto *callBase = SVFUtil::dyn_cast(val)) { +bool cppUtil::isClsNameSource(const Value *val) +{ + if (const auto *callBase = SVFUtil::dyn_cast(val)) + { const Function *foo = callBase->getCalledFunction(); return isConstructor(foo) || isDestructor(foo) || isTemplateFunc(foo) || isDynCast(foo); } @@ -768,7 +800,8 @@ bool cppUtil::isClsNameSource(const Value *val) { * @param label * @return */ -bool cppUtil::matchesLabel(const std::string &foo, const std::string &label) { +bool cppUtil::matchesLabel(const std::string &foo, const std::string &label) +{ return foo.compare(0, label.size(), label) == 0; } @@ -777,7 +810,8 @@ bool cppUtil::matchesLabel(const std::string &foo, const std::string &label) { * @param foo * @return */ -bool cppUtil::isTemplateFunc(const Function *foo) { +bool cppUtil::isTemplateFunc(const Function *foo) +{ assert(foo->hasName() && "foo does not have a name? possible indirect call"); const std::string &name = foo->getName().str(); return matchesLabel(name, znstLabel) || matchesLabel(name, znkstLabel) || @@ -789,7 +823,8 @@ bool cppUtil::isTemplateFunc(const Function *foo) { * @param foo * @return */ -bool cppUtil::isDynCast(const Function *foo) { +bool cppUtil::isDynCast(const Function *foo) +{ assert(foo->hasName() && "foo does not have a name? possible indirect call"); return foo->getName().str() == dyncast; } @@ -799,7 +834,8 @@ bool cppUtil::isDynCast(const Function *foo) { * @param foo * @return */ -bool cppUtil::isNewAlloc(const Function *foo) { +bool cppUtil::isNewAlloc(const Function *foo) +{ assert(foo->hasName() && "foo does not have a name? possible indirect call"); return foo->getName().str() == znwm; } @@ -809,7 +845,8 @@ bool cppUtil::isNewAlloc(const Function *foo) { * @param callBase * @return */ -std::string cppUtil::extractClsNameFromDynCast(const CallBase* callBase) { +std::string cppUtil::extractClsNameFromDynCast(const CallBase* callBase) +{ Value *tgtCast = callBase->getArgOperand(2); const std::string &valueStr = LLVMUtil::dumpValue(tgtCast); u32_t leftPos = valueStr.find(ztilabel); @@ -819,22 +856,25 @@ std::string cppUtil::extractClsNameFromDynCast(const CallBase* callBase) { const std::string &substr = valueStr.substr(leftPos, rightPos - leftPos); std::string demangleName = llvm::demangle(substr); const std::string &realName = demangleName.substr(ztiprefix.size(), - demangleName.size() - ztiprefix.size()); + demangleName.size() - ztiprefix.size()); assert(realName != "" && "real name for dyncast empty?"); return realName; } -const Type *cppUtil::cppClsNameToType(const std::string &className) { +const Type *cppUtil::cppClsNameToType(const std::string &className) +{ StructType *classTy = StructType::getTypeByName(LLVMModuleSet::getLLVMModuleSet()->getContext(), - clsName + className); + clsName + className); return classTy ? classTy : LLVMModuleSet::getLLVMModuleSet()->getTypeInference()->ptrType(); } -std::string cppUtil::typeToClsName(const Type *ty) { - if (const auto *stTy = SVFUtil::dyn_cast(ty)) { +std::string cppUtil::typeToClsName(const Type *ty) +{ + if (const auto *stTy = SVFUtil::dyn_cast(ty)) + { const std::string &typeName = stTy->getName().str(); const std::string &className = typeName.substr( - clsName.size(), typeName.size() - clsName.size()); + clsName.size(), typeName.size() - clsName.size()); return className; } return ""; diff --git a/svf-llvm/lib/ObjTypeInference.cpp b/svf-llvm/lib/ObjTypeInference.cpp index cc5f5cb4a..a716b2516 100644 --- a/svf-llvm/lib/ObjTypeInference.cpp +++ b/svf-llvm/lib/ObjTypeInference.cpp @@ -80,34 +80,48 @@ const std::string TYPEMALLOC = "TYPE_MALLOC"; /// Determine type based on infer site /// https://llvm.org/docs/OpaquePointers.html#migration-instructions -const Type *infersiteToType(const Value *val) { +const Type *infersiteToType(const Value *val) +{ assert(val && "value cannot be empty"); - if (SVFUtil::isa(val)) { + if (SVFUtil::isa(val)) + { return llvm::getLoadStoreType(const_cast(val)); - } else if (const auto *gepInst = SVFUtil::dyn_cast(val)) { + } + else if (const auto *gepInst = SVFUtil::dyn_cast(val)) + { return gepInst->getSourceElementType(); - } else if (const auto *call = SVFUtil::dyn_cast(val)) { + } + else if (const auto *call = SVFUtil::dyn_cast(val)) + { return call->getFunctionType(); - } else if (const auto *allocaInst = SVFUtil::dyn_cast(val)) { + } + else if (const auto *allocaInst = SVFUtil::dyn_cast(val)) + { return allocaInst->getAllocatedType(); - } else if (const auto *globalValue = SVFUtil::dyn_cast(val)) { + } + else if (const auto *globalValue = SVFUtil::dyn_cast(val)) + { return globalValue->getValueType(); - } else { + } + else + { ABORT_MSG("unknown value:" + dumpValueAndDbgInfo(val)); } } -const Type *ObjTypeInference::defaultType(const Value *val) { +const Type *ObjTypeInference::defaultType(const Value *val) +{ ABORT_IFNOT(val, "val cannot be null"); // heap has a default type of 8-bit integer type if (SVFUtil::isa(val) && SVFUtil::isHeapAllocExtCallViaRet( - LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(SVFUtil::cast(val)))) + LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(SVFUtil::cast(val)))) return int8Type(); // otherwise we return a pointer type in the default address space return ptrType(); } -LLVMContext &ObjTypeInference::getLLVMCtx() { +LLVMContext &ObjTypeInference::getLLVMCtx() +{ return LLVMModuleSet::getLLVMModuleSet()->getContext(); } @@ -117,15 +131,20 @@ LLVMContext &ObjTypeInference::getLLVMCtx() { * if not, find allocations and then forward get or infer types * @param val */ -const Type *ObjTypeInference::inferObjType(const Value *var) { +const Type *ObjTypeInference::inferObjType(const Value *var) +{ if (isAlloc(var)) return fwInferObjType(var); Set &sources = bwfindAllocOfVar(var); Set types; - if (sources.empty()) { + if (sources.empty()) + { // cannot find allocation, try to fw infer starting from var types.insert(fwInferObjType(var)); - } else { - for (const auto &source: sources) { + } + else + { + for (const auto &source: sources) + { types.insert(fwInferObjType(source)); } } @@ -138,10 +157,12 @@ const Type *ObjTypeInference::inferObjType(const Value *var) { * forward infer the type of the object pointed by var * @param var */ -const Type *ObjTypeInference::fwInferObjType(const Value *var) { +const Type *ObjTypeInference::fwInferObjType(const Value *var) +{ // consult cache auto tIt = _valueToType.find(var); - if (tIt != _valueToType.end()) { + if (tIt != _valueToType.end()) + { return tIt->second ? tIt->second : defaultType(var); } @@ -150,7 +171,8 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { Set visited; workList.push({var, false}); - while (!workList.empty()) { + while (!workList.empty()) + { auto curPair = workList.pop(); if (visited.count(curPair)) continue; visited.insert(curPair); @@ -158,26 +180,35 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { bool canUpdate = curPair.second; Set infersites; - auto insertInferSite = [&infersites, &canUpdate](const Value *infersite) { + auto insertInferSite = [&infersites, &canUpdate](const Value *infersite) + { if (canUpdate) infersites.insert(infersite); }; - auto insertInferSitesOrPushWorklist = [this, &infersites, &workList, &canUpdate](const auto &pUser) { + auto insertInferSitesOrPushWorklist = [this, &infersites, &workList, &canUpdate](const auto &pUser) + { auto vIt = _valueToInferSites.find(pUser); - if (canUpdate) { - if (vIt != _valueToInferSites.end()) { + if (canUpdate) + { + if (vIt != _valueToInferSites.end()) + { infersites.insert(vIt->second.begin(), vIt->second.end()); } - } else { + } + else + { if (vIt == _valueToInferSites.end()) workList.push({pUser, false}); } }; - if (!canUpdate && !_valueToInferSites.count(curValue)) { + if (!canUpdate && !_valueToInferSites.count(curValue)) + { workList.push({curValue, true}); } if (const auto *gepInst = SVFUtil::dyn_cast(curValue)) insertInferSite(gepInst); - for (const auto &it: curValue->uses()) { - if (const auto *loadInst = SVFUtil::dyn_cast(it.getUser())) { + for (const auto &it: curValue->uses()) + { + if (const auto *loadInst = SVFUtil::dyn_cast(it.getUser())) + { /* * infer based on load, e.g., %call = call i8* malloc() @@ -185,8 +216,11 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { %q = load %struct.MyStruct, %struct.MyStruct* %1 */ insertInferSite(loadInst); - } else if (const auto *storeInst = SVFUtil::dyn_cast(it.getUser())) { - if (storeInst->getPointerOperand() == curValue) { + } + else if (const auto *storeInst = SVFUtil::dyn_cast(it.getUser())) + { + if (storeInst->getPointerOperand() == curValue) + { /* * infer based on store (pointer operand), e.g., %call = call i8* malloc() @@ -194,8 +228,11 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { store %struct.MyStruct .., %struct.MyStruct* %1 */ insertInferSite(storeInst); - } else { - for (const auto &nit: storeInst->getPointerOperand()->uses()) { + } + else + { + for (const auto &nit: storeInst->getPointerOperand()->uses()) + { /* * propagate across store (value operand) and load %call = call i8* malloc() @@ -219,17 +256,22 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { infer site -> %f1 = getelementptr inbounds %struct.MyStruct, %struct.MyStruct* %6, i32 0, i32 0, !dbg !50 */ if (const auto *gepInst = SVFUtil::dyn_cast( - storeInst->getPointerOperand())) { + storeInst->getPointerOperand())) + { const Value *gepBase = gepInst->getPointerOperand(); if (!SVFUtil::isa(gepBase)) continue; const auto *load = SVFUtil::dyn_cast(gepBase); - for (const auto &loadUse: load->getPointerOperand()->uses()) { + for (const auto &loadUse: load->getPointerOperand()->uses()) + { if (loadUse.getUser() == load || !SVFUtil::isa(loadUse.getUser())) continue; - for (const auto &gepUse: loadUse.getUser()->uses()) { + for (const auto &gepUse: loadUse.getUser()->uses()) + { if (!SVFUtil::isa(gepUse.getUser())) continue; - for (const auto &loadUse2: gepUse.getUser()->uses()) { - if (SVFUtil::isa(loadUse2.getUser())) { + for (const auto &loadUse2: gepUse.getUser()->uses()) + { + if (SVFUtil::isa(loadUse2.getUser())) + { insertInferSitesOrPushWorklist(loadUse2.getUser()); } } @@ -239,7 +281,9 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { } } - } else if (const auto *gepInst = SVFUtil::dyn_cast(it.getUser())) { + } + else if (const auto *gepInst = SVFUtil::dyn_cast(it.getUser())) + { /* * infer based on gep (pointer operand) %call = call i8* malloc() @@ -248,13 +292,19 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { */ if (gepInst->getPointerOperand() == curValue) insertInferSite(gepInst); - } else if (const auto *bitcast = SVFUtil::dyn_cast(it.getUser())) { + } + else if (const auto *bitcast = SVFUtil::dyn_cast(it.getUser())) + { // continue on bitcast insertInferSitesOrPushWorklist(bitcast); - } else if (const auto *phiNode = SVFUtil::dyn_cast(it.getUser())) { + } + else if (const auto *phiNode = SVFUtil::dyn_cast(it.getUser())) + { // continue on bitcast insertInferSitesOrPushWorklist(phiNode); - } else if (const auto *retInst = SVFUtil::dyn_cast(it.getUser())) { + } + else if (const auto *retInst = SVFUtil::dyn_cast(it.getUser())) + { /* * propagate from return to caller Function Attrs: noinline nounwind optnone uwtable @@ -266,15 +316,19 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { %call = call i8* @malloc_wrapper() ..infer based on %call.. */ - for (const auto &callsite: retInst->getFunction()->uses()) { - if (const auto *callBase = SVFUtil::dyn_cast(callsite.getUser())) { + for (const auto &callsite: retInst->getFunction()->uses()) + { + if (const auto *callBase = SVFUtil::dyn_cast(callsite.getUser())) + { // skip function as parameter // e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback) if (callBase->getCalledFunction() != retInst->getFunction()) continue; insertInferSitesOrPushWorklist(callBase); } } - } else if (const auto *callBase = SVFUtil::dyn_cast(it.getUser())) { + } + else if (const auto *callBase = SVFUtil::dyn_cast(it.getUser())) + { /* * propagate from callsite to callee %call = call i8* @malloc(i32 noundef 16) @@ -291,17 +345,20 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { // skip indirect call // e.g., %0 = ... -> call %0(...) if (!callBase->hasArgument(curValue)) continue; - if (Function *calleeFunc = callBase->getCalledFunction()) { + if (Function *calleeFunc = callBase->getCalledFunction()) + { u32_t pos = getArgPosInCall(callBase, curValue); // for variable argument, conservatively collect all params if (calleeFunc->isVarArg()) pos = 0; - if (!calleeFunc->isDeclaration()) { + if (!calleeFunc->isDeclaration()) + { insertInferSitesOrPushWorklist(calleeFunc->getArg(pos)); } } } } - if (canUpdate) { + if (canUpdate) + { Set types; std::transform(infersites.begin(), infersites.end(), std::inserter(types, types.begin()), infersiteToType); @@ -310,7 +367,8 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { } } const Type *type = _valueToType[var]; - if (type == nullptr) { + if (type == nullptr) + { type = defaultType(var); WARN_MSG("Using default type, trace ID is " + std::to_string(traceId) + ":" + dumpValueAndDbgInfo(var)); } @@ -323,11 +381,13 @@ const Type *ObjTypeInference::fwInferObjType(const Value *var) { * @param var * @return */ -Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) { +Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) +{ // consult cache auto tIt = _valueToAllocs.find(var); - if (tIt != _valueToAllocs.end()) { + if (tIt != _valueToAllocs.end()) + { return tIt->second; } @@ -335,7 +395,8 @@ Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) { FILOWorkList workList; Set visited; workList.push({var, false}); - while (!workList.empty()) { + while (!workList.empty()) + { auto curPair = workList.pop(); if (visited.count(curPair)) continue; visited.insert(curPair); @@ -343,44 +404,66 @@ Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) { bool canUpdate = curPair.second; Set sources; - auto insertAllocs = [&sources, &canUpdate](const Value *source) { + auto insertAllocs = [&sources, &canUpdate](const Value *source) + { if (canUpdate) sources.insert(source); }; - auto insertAllocsOrPushWorklist = [this, &sources, &workList, &canUpdate](const auto &pUser) { + auto insertAllocsOrPushWorklist = [this, &sources, &workList, &canUpdate](const auto &pUser) + { auto vIt = _valueToAllocs.find(pUser); - if (canUpdate) { - if (vIt != _valueToAllocs.end()) { + if (canUpdate) + { + if (vIt != _valueToAllocs.end()) + { sources.insert(vIt->second.begin(), vIt->second.end()); } - } else { + } + else + { if (vIt == _valueToAllocs.end()) workList.push({pUser, false}); } }; - if (!canUpdate && !_valueToAllocs.count(curValue)) { + if (!canUpdate && !_valueToAllocs.count(curValue)) + { workList.push({curValue, true}); } - if (isAlloc(curValue)) { + if (isAlloc(curValue)) + { insertAllocs(curValue); - } else if (const auto *bitCastInst = SVFUtil::dyn_cast(curValue)) { + } + else if (const auto *bitCastInst = SVFUtil::dyn_cast(curValue)) + { Value *prevVal = bitCastInst->getOperand(0); insertAllocsOrPushWorklist(prevVal); - } else if (const auto *phiNode = SVFUtil::dyn_cast(curValue)) { - for (u32_t i = 0; i < phiNode->getNumOperands(); ++i) { + } + else if (const auto *phiNode = SVFUtil::dyn_cast(curValue)) + { + for (u32_t i = 0; i < phiNode->getNumOperands(); ++i) + { insertAllocsOrPushWorklist(phiNode->getOperand(i)); } - } else if (const auto *loadInst = SVFUtil::dyn_cast(curValue)) { - for (const auto &use: loadInst->getPointerOperand()->uses()) { - if (const StoreInst *storeInst = SVFUtil::dyn_cast(use.getUser())) { - if (storeInst->getPointerOperand() == loadInst->getPointerOperand()) { + } + else if (const auto *loadInst = SVFUtil::dyn_cast(curValue)) + { + for (const auto &use: loadInst->getPointerOperand()->uses()) + { + if (const StoreInst *storeInst = SVFUtil::dyn_cast(use.getUser())) + { + if (storeInst->getPointerOperand() == loadInst->getPointerOperand()) + { insertAllocsOrPushWorklist(storeInst->getValueOperand()); } } } - } else if (const auto *argument = SVFUtil::dyn_cast(curValue)) { - for (const auto &use: argument->getParent()->uses()) { - if (const CallBase *callBase = SVFUtil::dyn_cast(use.getUser())) { + } + else if (const auto *argument = SVFUtil::dyn_cast(curValue)) + { + for (const auto &use: argument->getParent()->uses()) + { + if (const CallBase *callBase = SVFUtil::dyn_cast(use.getUser())) + { // skip function as parameter // e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback) if (callBase->getCalledFunction() != argument->getParent()) continue; @@ -388,10 +471,14 @@ Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) { insertAllocsOrPushWorklist(callBase->getArgOperand(pos)); } } - } else if (const auto *callBase = SVFUtil::dyn_cast(curValue)) { + } + else if (const auto *callBase = SVFUtil::dyn_cast(curValue)) + { ABORT_IFNOT(!callBase->doesNotReturn(), "callbase does not return:" + dumpValueAndDbgInfo(callBase)); - if (Function *callee = callBase->getCalledFunction()) { - if (!callee->isDeclaration()) { + if (Function *callee = callBase->getCalledFunction()) + { + if (!callee->isDeclaration()) + { const SVFFunction *svfFunc = LLVMModuleSet::getLLVMModuleSet()->getSVFFunction(callee); const Value *pValue = LLVMModuleSet::getLLVMModuleSet()->getLLVMValue(svfFunc->getExitBB()->back()); const auto *retInst = SVFUtil::dyn_cast(pValue); @@ -400,18 +487,21 @@ Set &ObjTypeInference::bwfindAllocOfVar(const Value *var) { } } } - if (canUpdate) { + if (canUpdate) + { _valueToAllocs[curValue] = SVFUtil::move(sources); } } Set &srcs = _valueToAllocs[var]; - if (srcs.empty()) { + if (srcs.empty()) + { WARN_MSG("Cannot find allocation: " + dumpValueAndDbgInfo(var)); } return srcs; } -bool ObjTypeInference::isAlloc(const SVF::Value *val) { +bool ObjTypeInference::isAlloc(const SVF::Value *val) +{ return LLVMUtil::isObject(val); } @@ -419,19 +509,23 @@ bool ObjTypeInference::isAlloc(const SVF::Value *val) { * validate type inference * @param cs : stub malloc function with element number label */ -void ObjTypeInference::validateTypeCheck(const CallBase *cs) { - if (const Function *func = cs->getCalledFunction()) { - if (func->getName().find(TYPEMALLOC) != std::string::npos) { +void ObjTypeInference::validateTypeCheck(const CallBase *cs) +{ + if (const Function *func = cs->getCalledFunction()) + { + if (func->getName().find(TYPEMALLOC) != std::string::npos) + { const Type *objType = fwInferObjType(cs); const auto *pInt = - SVFUtil::dyn_cast(cs->getOperand(1)); + SVFUtil::dyn_cast(cs->getOperand(1)); assert(pInt && "the second argument is a integer"); u32_t iTyNum = objTyToNumFields(objType); if (iTyNum >= pInt->getZExtValue()) SVFUtil::outs() << SVFUtil::sucMsg("\t SUCCESS :") << dumpValueAndDbgInfo(cs) << SVFUtil::pasMsg(" TYPE: ") << dumpType(objType) << "\n"; - else { + else + { SVFUtil::errs() << SVFUtil::errMsg("\t FAILURE :") << ":" << dumpValueAndDbgInfo(cs) << " TYPE: " << dumpType(objType) << "\n"; abort(); @@ -440,7 +534,8 @@ void ObjTypeInference::validateTypeCheck(const CallBase *cs) { } } -void ObjTypeInference::typeSizeDiffTest(const PointerType *oPTy, const Type *iTy, const Value *val) { +void ObjTypeInference::typeSizeDiffTest(const PointerType *oPTy, const Type *iTy, const Value *val) +{ #if TYPE_DEBUG Type *oTy = getPtrElementType(oPTy); u32_t iTyNum = objTyToNumFields(iTy); @@ -453,7 +548,8 @@ void ObjTypeInference::typeSizeDiffTest(const PointerType *oPTy, const Type *iTy #endif } -u32_t ObjTypeInference::getArgPosInCall(const CallBase *callBase, const Value *arg) { +u32_t ObjTypeInference::getArgPosInCall(const CallBase *callBase, const Value *arg) +{ assert(callBase->hasArgument(arg) && "callInst does not have argument arg?"); auto it = std::find(callBase->arg_begin(), callBase->arg_end(), arg); assert(it != callBase->arg_end() && "Didn't find argument?"); @@ -461,11 +557,13 @@ u32_t ObjTypeInference::getArgPosInCall(const CallBase *callBase, const Value *a } -const Type *ObjTypeInference::selectLargestSizedType(Set &objTys) { +const Type *ObjTypeInference::selectLargestSizedType(Set &objTys) +{ if (objTys.empty()) return nullptr; // map type size to types from with key in descending order OrderedMap, std::greater> typeSzToTypes; - for (const Type *ty: objTys) { + for (const Type *ty: objTys) + { typeSzToTypes[objTyToNumFields(ty)].insert(ty); } assert(!typeSzToTypes.empty() && "typeSzToTypes cannot be empty"); @@ -475,11 +573,13 @@ const Type *ObjTypeInference::selectLargestSizedType(Set &objTys) return *largestTypes.begin(); } -u32_t ObjTypeInference::objTyToNumFields(const Type *objTy) { +u32_t ObjTypeInference::objTyToNumFields(const Type *objTy) +{ u32_t num = Options::MaxFieldLimit(); if (SVFUtil::isa(objTy)) num = getNumOfElements(objTy); - else if (const auto *st = SVFUtil::dyn_cast(objTy)) { + else if (const auto *st = SVFUtil::dyn_cast(objTy)) + { /// For an C++ class, it can have variant elements depending on the vtable size, /// Hence we only handle non-cpp-class object, the type of the cpp class is treated as default PointerType if (!classTyHasVTable(st)) @@ -494,44 +594,58 @@ u32_t ObjTypeInference::objTyToNumFields(const Type *objTy) { * @param thisPtr * @return */ -Set &ObjTypeInference::inferThisPtrClsName(const Value *thisPtr) { +Set &ObjTypeInference::inferThisPtrClsName(const Value *thisPtr) +{ auto it = _thisPtrClassNames.find(thisPtr); if (it != _thisPtrClassNames.end()) return it->second; Set names; - auto insertClassNames = [&names](Set &classNames) { + auto insertClassNames = [&names](Set &classNames) + { names.insert(classNames.begin(), classNames.end()); }; // backward find heap allocations or class name sources Set &vals = bwFindAllocOrClsNameSources(thisPtr); - for (const auto &val: vals) { + for (const auto &val: vals) + { if (val == thisPtr) continue; - if (const auto *func = SVFUtil::dyn_cast(val)) { + if (const auto *func = SVFUtil::dyn_cast(val)) + { // extract class name from function name Set classNames = extractClsNamesFromFunc(func); insertClassNames(classNames); - } else if (SVFUtil::isa(val)) { + } + else if (SVFUtil::isa(val)) + { // extract class name from instructions const Type *type = infersiteToType(val); const std::string &className = typeToClsName(type); - if (!className.empty()) { + if (!className.empty()) + { Set tgt{className}; insertClassNames(tgt); } - } else if (const auto *callBase = SVFUtil::dyn_cast(val)) { - if (const Function *callFunc = callBase->getCalledFunction()) { + } + else if (const auto *callBase = SVFUtil::dyn_cast(val)) + { + if (const Function *callFunc = callBase->getCalledFunction()) + { Set classNames = extractClsNamesFromFunc(callFunc); insertClassNames(classNames); - if (isDynCast(callFunc)) { + if (isDynCast(callFunc)) + { // dynamic cast Set tgt{extractClsNameFromDynCast(callBase)}; insertClassNames(tgt); - } else if (isNewAlloc(callFunc)) { + } + else if (isNewAlloc(callFunc)) + { // for heap allocation, we forward find class name sources Set& srcs = fwFindClsNameSources(callBase); - for (const auto &src: srcs) { + for (const auto &src: srcs) + { classNames = extractClsNamesFromFunc(src); insertClassNames(classNames); } @@ -549,11 +663,13 @@ Set &ObjTypeInference::inferThisPtrClsName(const Value *thisPtr) { * @param startValue * @return */ -Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *startValue) { +Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *startValue) +{ // consult cache auto tIt = _valueToAllocOrClsNameSources.find(startValue); - if (tIt != _valueToAllocOrClsNameSources.end()) { + if (tIt != _valueToAllocOrClsNameSources.end()) + { return tIt->second; } @@ -561,7 +677,8 @@ Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s FILOWorkList workList; Set visited; workList.push({startValue, false}); - while (!workList.empty()) { + while (!workList.empty()) + { auto curPair = workList.pop(); if (visited.count(curPair)) continue; visited.insert(curPair); @@ -569,59 +686,87 @@ Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s bool canUpdate = curPair.second; Set sources; - auto insertSource = [&sources, &canUpdate](const Value *source) { + auto insertSource = [&sources, &canUpdate](const Value *source) + { if (canUpdate) sources.insert(source); }; - auto insertSourcesOrPushWorklist = [this, &sources, &workList, &canUpdate](const auto &pUser) { + auto insertSourcesOrPushWorklist = [this, &sources, &workList, &canUpdate](const auto &pUser) + { auto vIt = _valueToAllocOrClsNameSources.find(pUser); - if (canUpdate) { - if (vIt != _valueToAllocOrClsNameSources.end() && !vIt->second.empty()) { + if (canUpdate) + { + if (vIt != _valueToAllocOrClsNameSources.end() && !vIt->second.empty()) + { sources.insert(vIt->second.begin(), vIt->second.end()); } - } else { + } + else + { if (vIt == _valueToAllocOrClsNameSources.end()) workList.push({pUser, false}); } }; - if (!canUpdate && !_valueToAllocOrClsNameSources.count(curValue)) { + if (!canUpdate && !_valueToAllocOrClsNameSources.count(curValue)) + { workList.push({curValue, true}); } // current inst reside in cpp self-inference function - if (const auto *inst = SVFUtil::dyn_cast(curValue)) { - if (const Function *foo = inst->getFunction()) { - if (isConstructor(foo) || isDestructor(foo) || isTemplateFunc(foo) || isDynCast(foo)) { + if (const auto *inst = SVFUtil::dyn_cast(curValue)) + { + if (const Function *foo = inst->getFunction()) + { + if (isConstructor(foo) || isDestructor(foo) || isTemplateFunc(foo) || isDynCast(foo)) + { insertSource(foo); - if (canUpdate) { + if (canUpdate) + { _valueToAllocOrClsNameSources[curValue] = sources; } continue; } } } - if (isAlloc(curValue) || isClsNameSource(curValue)) { + if (isAlloc(curValue) || isClsNameSource(curValue)) + { insertSource(curValue); - } else if (const auto *getElementPtrInst = SVFUtil::dyn_cast(curValue)) { + } + else if (const auto *getElementPtrInst = SVFUtil::dyn_cast(curValue)) + { insertSource(getElementPtrInst); insertSourcesOrPushWorklist(getElementPtrInst->getPointerOperand()); - } else if (const auto *bitCastInst = SVFUtil::dyn_cast(curValue)) { + } + else if (const auto *bitCastInst = SVFUtil::dyn_cast(curValue)) + { Value *prevVal = bitCastInst->getOperand(0); insertSourcesOrPushWorklist(prevVal); - } else if (const auto *phiNode = SVFUtil::dyn_cast(curValue)) { - for (u32_t i = 0; i < phiNode->getNumOperands(); ++i) { + } + else if (const auto *phiNode = SVFUtil::dyn_cast(curValue)) + { + for (u32_t i = 0; i < phiNode->getNumOperands(); ++i) + { insertSourcesOrPushWorklist(phiNode->getOperand(i)); } - } else if (const auto *loadInst = SVFUtil::dyn_cast(curValue)) { - for (const auto &use: loadInst->getPointerOperand()->uses()) { - if (const auto *storeInst = SVFUtil::dyn_cast(use.getUser())) { - if (storeInst->getPointerOperand() == loadInst->getPointerOperand()) { + } + else if (const auto *loadInst = SVFUtil::dyn_cast(curValue)) + { + for (const auto &use: loadInst->getPointerOperand()->uses()) + { + if (const auto *storeInst = SVFUtil::dyn_cast(use.getUser())) + { + if (storeInst->getPointerOperand() == loadInst->getPointerOperand()) + { insertSourcesOrPushWorklist(storeInst->getValueOperand()); } } } - } else if (const auto *argument = SVFUtil::dyn_cast(curValue)) { - for (const auto &use: argument->getParent()->uses()) { - if (const auto *callBase = SVFUtil::dyn_cast(use.getUser())) { + } + else if (const auto *argument = SVFUtil::dyn_cast(curValue)) + { + for (const auto &use: argument->getParent()->uses()) + { + if (const auto *callBase = SVFUtil::dyn_cast(use.getUser())) + { // skip function as parameter // e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback) if (callBase->getCalledFunction() != argument->getParent()) continue; @@ -629,10 +774,14 @@ Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s insertSourcesOrPushWorklist(callBase->getArgOperand(pos)); } } - } else if (const auto *callBase = SVFUtil::dyn_cast(curValue)) { + } + else if (const auto *callBase = SVFUtil::dyn_cast(curValue)) + { ABORT_IFNOT(!callBase->doesNotReturn(), "callbase does not return:" + dumpValueAndDbgInfo(callBase)); - if (Function *callee = callBase->getCalledFunction()) { - if (!callee->isDeclaration()) { + if (Function *callee = callBase->getCalledFunction()) + { + if (!callee->isDeclaration()) + { const SVFFunction *svfFunc = LLVMModuleSet::getLLVMModuleSet()->getSVFFunction(callee); const Value *pValue = LLVMModuleSet::getLLVMModuleSet()->getLLVMValue(svfFunc->getExitBB()->back()); const auto *retInst = SVFUtil::dyn_cast(pValue); @@ -641,33 +790,43 @@ Set &ObjTypeInference::bwFindAllocOrClsNameSources(const Value *s } } } - if (canUpdate) { + if (canUpdate) + { _valueToAllocOrClsNameSources[curValue] = sources; } } return _valueToAllocOrClsNameSources[startValue]; } -Set &ObjTypeInference::fwFindClsNameSources(const CallBase *alloc) { +Set &ObjTypeInference::fwFindClsNameSources(const CallBase *alloc) +{ // consult cache auto tIt = _allocToClsNameSources.find(alloc); - if (tIt != _allocToClsNameSources.end()) { + if (tIt != _allocToClsNameSources.end()) + { return tIt->second; } Set clsSources; // for heap allocation, we forward find class name sources - auto inferViaCppCall = [&clsSources](const CallBase *callBase) { + auto inferViaCppCall = [&clsSources](const CallBase *callBase) + { if (!callBase->getCalledFunction()) return; const Function *constructFoo = callBase->getCalledFunction(); clsSources.insert(constructFoo); }; - for (const auto &use: alloc->uses()) { - if (const auto *cppCall = SVFUtil::dyn_cast(use.getUser())) { + for (const auto &use: alloc->uses()) + { + if (const auto *cppCall = SVFUtil::dyn_cast(use.getUser())) + { inferViaCppCall(cppCall); - } else if (const auto *bitCastInst = SVFUtil::dyn_cast(use.getUser())) { - for (const auto &use2: bitCastInst->uses()) { - if (const auto *cppCall2 = SVFUtil::dyn_cast(use2.getUser())) { + } + else if (const auto *bitCastInst = SVFUtil::dyn_cast(use.getUser())) + { + for (const auto &use2: bitCastInst->uses()) + { + if (const auto *cppCall2 = SVFUtil::dyn_cast(use2.getUser())) + { inferViaCppCall(cppCall2); } }