Clang import: Handle AccessSpecDecl

This commit is contained in:
Daniel Marjamäki 2020-11-08 17:02:19 +01:00
parent 4a057c1813
commit 1c742b7995
2 changed files with 43 additions and 8 deletions

View File

@ -27,6 +27,7 @@
#include <vector>
#include <iostream>
static const std::string AccessSpecDecl = "AccessSpecDecl";
static const std::string ArraySubscriptExpr = "ArraySubscriptExpr";
static const std::string BinaryOperator = "BinaryOperator";
static const std::string BreakStmt = "BreakStmt";
@ -158,6 +159,11 @@ static std::vector<std::string> splitString(const std::string &line)
return ret;
}
static bool contains(const std::vector<std::string> &haystack, const std::string &needle)
{
return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
}
namespace clangimport {
struct Data {
struct Decl {
@ -245,6 +251,8 @@ namespace clangimport {
// "}" tokens that are not end-of-scope
std::set<Token *> mNotScope;
std::map<const Scope *, AccessControl> scopeAccessControl;
private:
void notFound(const std::string &addr) {
auto it = mNotFound.find(addr);
@ -381,7 +389,7 @@ std::string clangimport::AstNode::getType(int index) const
bool clangimport::AstNode::isDefinition() const
{
return std::find(mExtTokens.begin(), mExtTokens.end(), "definition") != mExtTokens.end();
return contains(mExtTokens, "definition");
}
std::string clangimport::AstNode::getTemplateParameters() const
@ -557,11 +565,21 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType
scope->classDef = def;
scope->check = nestedIn->check;
scope->bodyStart = addtoken(tokenList, "{");
mData->scopeAccessControl[scope] = (scopeType == Scope::ScopeType::eClass) ? AccessControl::Private : AccessControl::Public;
if (!children2.empty()) {
tokenList->back()->scope(scope);
for (AstNodePtr astNode: children2) {
if (astNode->nodeType == "VisibilityAttr")
continue;
if (astNode->nodeType == AccessSpecDecl) {
if (contains(astNode->mExtTokens, "private"))
mData->scopeAccessControl[scope] = AccessControl::Private;
else if (contains(astNode->mExtTokens, "protected"))
mData->scopeAccessControl[scope] = AccessControl::Protected;
else if (contains(astNode->mExtTokens, "public"))
mData->scopeAccessControl[scope] = AccessControl::Public;
continue;
}
astNode->createTokens(tokenList);
if (scopeType == Scope::ScopeType::eEnum)
astNode->addtoken(tokenList, ",");
@ -571,6 +589,7 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType
}
scope->bodyEnd = addtoken(tokenList, "}");
Token::createMutualLinks(const_cast<Token*>(scope->bodyStart), const_cast<Token*>(scope->bodyEnd));
mData->scopeAccessControl.erase(scope);
return scope;
}
@ -1140,10 +1159,10 @@ Token * clangimport::AstNode::createTokensCall(TokenList *tokenList)
void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
{
const bool prev = (std::find(mExtTokens.begin(), mExtTokens.end(), "prev") != mExtTokens.end());
const bool prev = contains(mExtTokens, "prev");
const bool hasBody = !children.empty() && children.back()->nodeType == CompoundStmt;
const bool isStatic = (std::find(mExtTokens.begin(), mExtTokens.end(), "static") != mExtTokens.end());
const bool isInline = (std::find(mExtTokens.begin(), mExtTokens.end(), "inline") != mExtTokens.end());
const bool isStatic = contains(mExtTokens, "static");
const bool isInline = contains(mExtTokens, "inline");
const Token *startToken = nullptr;
@ -1181,6 +1200,12 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
Function * const function = const_cast<Function*>(nameToken->function());
if (!prev) {
auto accessControl = mData->scopeAccessControl.find(tokenList->back()->scope());
if (accessControl != mData->scopeAccessControl.end())
function->access = accessControl->second;
}
Scope *scope = nullptr;
if (hasBody) {
symbolDatabase->scopeList.push_back(Scope(nullptr, nullptr, nestedIn));
@ -1241,7 +1266,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
bodyStart->link(bodyEnd);
bodyEnd->link(bodyStart);
} else {
if (nodeType == CXXConstructorDecl && (std::find(mExtTokens.begin(), mExtTokens.end(), "default") != mExtTokens.end())) {
if (nodeType == CXXConstructorDecl && contains(mExtTokens, "default")) {
addtoken(tokenList, "=");
addtoken(tokenList, "default");
}
@ -1252,7 +1277,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList)
{
bool isStruct = (std::find(mExtTokens.begin(), mExtTokens.end(), "struct") != mExtTokens.end());
bool isStruct = contains(mExtTokens, "struct");
Token * const classToken = addtoken(tokenList, isStruct ? "struct" : "class");
std::string className;
if (mExtTokens[mExtTokens.size() - 2] == (isStruct?"struct":"class"))
@ -1279,7 +1304,8 @@ void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList)
child->nodeType == CXXDestructorDecl ||
child->nodeType == CXXMethodDecl ||
child->nodeType == FieldDecl ||
child->nodeType == VarDecl)
child->nodeType == VarDecl ||
child->nodeType == AccessSpecDecl)
children2.push_back(child);
}
Scope *scope = createScope(tokenList, isStruct ? Scope::ScopeType::eStruct : Scope::ScopeType::eClass, children2, classToken);
@ -1298,7 +1324,7 @@ Token * clangimport::AstNode::createTokensVarDecl(TokenList *tokenList)
{
const std::string addr = mExtTokens.front();
const Token *startToken = nullptr;
if (std::find(mExtTokens.cbegin(), mExtTokens.cend(), "static") != mExtTokens.cend())
if (contains(mExtTokens, "static"))
startToken = addtoken(tokenList, "static");
int typeIndex = mExtTokens.size() - 1;
while (typeIndex > 1 && std::isalpha(mExtTokens[typeIndex][0]))

View File

@ -101,6 +101,15 @@ def test_symbol_database_6():
def test_symbol_database_7():
check_symbol_database('struct S {int x;}; void f(struct S *s) {}')
def test_symbol_database_class_access_1():
check_symbol_database('class Fred { void foo ( ) {} } ;')
def test_symbol_database_class_access_2():
check_symbol_database('class Fred { protected: void foo ( ) {} } ;')
def test_symbol_database_class_access_3():
check_symbol_database('class Fred { public: void foo ( ) {} } ;')
def test_symbol_database_operator():
check_symbol_database('struct Fred { void operator=(int x); };')