diff --git a/lib/clangimport.cpp b/lib/clangimport.cpp index 624c25582..c35227ea4 100644 --- a/lib/clangimport.cpp +++ b/lib/clangimport.cpp @@ -27,6 +27,7 @@ #include #include +static const std::string AccessSpecDecl = "AccessSpecDecl"; static const std::string ArraySubscriptExpr = "ArraySubscriptExpr"; static const std::string BinaryOperator = "BinaryOperator"; static const std::string BreakStmt = "BreakStmt"; @@ -158,6 +159,11 @@ static std::vector splitString(const std::string &line) return ret; } +static bool contains(const std::vector &haystack, const std::string &needle) +{ + return std::find(haystack.begin(), haystack.end(), needle) != haystack.end(); +} + namespace clangimport { struct Data { struct Decl { @@ -245,6 +251,8 @@ namespace clangimport { // "}" tokens that are not end-of-scope std::set mNotScope; + + std::map scopeAccessControl; private: void notFound(const std::string &addr) { auto it = mNotFound.find(addr); @@ -381,7 +389,7 @@ std::string clangimport::AstNode::getType(int index) const bool clangimport::AstNode::isDefinition() const { - return std::find(mExtTokens.begin(), mExtTokens.end(), "definition") != mExtTokens.end(); + return contains(mExtTokens, "definition"); } std::string clangimport::AstNode::getTemplateParameters() const @@ -557,11 +565,21 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType scope->classDef = def; scope->check = nestedIn->check; scope->bodyStart = addtoken(tokenList, "{"); + mData->scopeAccessControl[scope] = (scopeType == Scope::ScopeType::eClass) ? AccessControl::Private : AccessControl::Public; if (!children2.empty()) { tokenList->back()->scope(scope); for (AstNodePtr astNode: children2) { if (astNode->nodeType == "VisibilityAttr") continue; + if (astNode->nodeType == AccessSpecDecl) { + if (contains(astNode->mExtTokens, "private")) + mData->scopeAccessControl[scope] = AccessControl::Private; + else if (contains(astNode->mExtTokens, "protected")) + mData->scopeAccessControl[scope] = AccessControl::Protected; + else if (contains(astNode->mExtTokens, "public")) + mData->scopeAccessControl[scope] = AccessControl::Public; + continue; + } astNode->createTokens(tokenList); if (scopeType == Scope::ScopeType::eEnum) astNode->addtoken(tokenList, ","); @@ -571,6 +589,7 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType } scope->bodyEnd = addtoken(tokenList, "}"); Token::createMutualLinks(const_cast(scope->bodyStart), const_cast(scope->bodyEnd)); + mData->scopeAccessControl.erase(scope); return scope; } @@ -1140,10 +1159,10 @@ Token * clangimport::AstNode::createTokensCall(TokenList *tokenList) void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList) { - const bool prev = (std::find(mExtTokens.begin(), mExtTokens.end(), "prev") != mExtTokens.end()); + const bool prev = contains(mExtTokens, "prev"); const bool hasBody = !children.empty() && children.back()->nodeType == CompoundStmt; - const bool isStatic = (std::find(mExtTokens.begin(), mExtTokens.end(), "static") != mExtTokens.end()); - const bool isInline = (std::find(mExtTokens.begin(), mExtTokens.end(), "inline") != mExtTokens.end()); + const bool isStatic = contains(mExtTokens, "static"); + const bool isInline = contains(mExtTokens, "inline"); const Token *startToken = nullptr; @@ -1181,6 +1200,12 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList) Function * const function = const_cast(nameToken->function()); + if (!prev) { + auto accessControl = mData->scopeAccessControl.find(tokenList->back()->scope()); + if (accessControl != mData->scopeAccessControl.end()) + function->access = accessControl->second; + } + Scope *scope = nullptr; if (hasBody) { symbolDatabase->scopeList.push_back(Scope(nullptr, nullptr, nestedIn)); @@ -1241,7 +1266,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList) bodyStart->link(bodyEnd); bodyEnd->link(bodyStart); } else { - if (nodeType == CXXConstructorDecl && (std::find(mExtTokens.begin(), mExtTokens.end(), "default") != mExtTokens.end())) { + if (nodeType == CXXConstructorDecl && contains(mExtTokens, "default")) { addtoken(tokenList, "="); addtoken(tokenList, "default"); } @@ -1252,7 +1277,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList) void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList) { - bool isStruct = (std::find(mExtTokens.begin(), mExtTokens.end(), "struct") != mExtTokens.end()); + bool isStruct = contains(mExtTokens, "struct"); Token * const classToken = addtoken(tokenList, isStruct ? "struct" : "class"); std::string className; if (mExtTokens[mExtTokens.size() - 2] == (isStruct?"struct":"class")) @@ -1279,7 +1304,8 @@ void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList) child->nodeType == CXXDestructorDecl || child->nodeType == CXXMethodDecl || child->nodeType == FieldDecl || - child->nodeType == VarDecl) + child->nodeType == VarDecl || + child->nodeType == AccessSpecDecl) children2.push_back(child); } Scope *scope = createScope(tokenList, isStruct ? Scope::ScopeType::eStruct : Scope::ScopeType::eClass, children2, classToken); @@ -1298,7 +1324,7 @@ Token * clangimport::AstNode::createTokensVarDecl(TokenList *tokenList) { const std::string addr = mExtTokens.front(); const Token *startToken = nullptr; - if (std::find(mExtTokens.cbegin(), mExtTokens.cend(), "static") != mExtTokens.cend()) + if (contains(mExtTokens, "static")) startToken = addtoken(tokenList, "static"); int typeIndex = mExtTokens.size() - 1; while (typeIndex > 1 && std::isalpha(mExtTokens[typeIndex][0])) diff --git a/test/cli/test-clang-import.py b/test/cli/test-clang-import.py index 073404a5a..bee6bbeca 100644 --- a/test/cli/test-clang-import.py +++ b/test/cli/test-clang-import.py @@ -101,6 +101,15 @@ def test_symbol_database_6(): def test_symbol_database_7(): check_symbol_database('struct S {int x;}; void f(struct S *s) {}') +def test_symbol_database_class_access_1(): + check_symbol_database('class Fred { void foo ( ) {} } ;') + +def test_symbol_database_class_access_2(): + check_symbol_database('class Fred { protected: void foo ( ) {} } ;') + +def test_symbol_database_class_access_3(): + check_symbol_database('class Fred { public: void foo ( ) {} } ;') + def test_symbol_database_operator(): check_symbol_database('struct Fred { void operator=(int x); };')