Clang import: Handle AccessSpecDecl

This commit is contained in:
Daniel Marjamäki 2020-11-08 17:02:19 +01:00
parent 4a057c1813
commit 1c742b7995
2 changed files with 43 additions and 8 deletions

View File

@ -27,6 +27,7 @@
#include <vector> #include <vector>
#include <iostream> #include <iostream>
static const std::string AccessSpecDecl = "AccessSpecDecl";
static const std::string ArraySubscriptExpr = "ArraySubscriptExpr"; static const std::string ArraySubscriptExpr = "ArraySubscriptExpr";
static const std::string BinaryOperator = "BinaryOperator"; static const std::string BinaryOperator = "BinaryOperator";
static const std::string BreakStmt = "BreakStmt"; static const std::string BreakStmt = "BreakStmt";
@ -158,6 +159,11 @@ static std::vector<std::string> splitString(const std::string &line)
return ret; return ret;
} }
static bool contains(const std::vector<std::string> &haystack, const std::string &needle)
{
return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
}
namespace clangimport { namespace clangimport {
struct Data { struct Data {
struct Decl { struct Decl {
@ -245,6 +251,8 @@ namespace clangimport {
// "}" tokens that are not end-of-scope // "}" tokens that are not end-of-scope
std::set<Token *> mNotScope; std::set<Token *> mNotScope;
std::map<const Scope *, AccessControl> scopeAccessControl;
private: private:
void notFound(const std::string &addr) { void notFound(const std::string &addr) {
auto it = mNotFound.find(addr); auto it = mNotFound.find(addr);
@ -381,7 +389,7 @@ std::string clangimport::AstNode::getType(int index) const
bool clangimport::AstNode::isDefinition() const bool clangimport::AstNode::isDefinition() const
{ {
return std::find(mExtTokens.begin(), mExtTokens.end(), "definition") != mExtTokens.end(); return contains(mExtTokens, "definition");
} }
std::string clangimport::AstNode::getTemplateParameters() const std::string clangimport::AstNode::getTemplateParameters() const
@ -557,11 +565,21 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType
scope->classDef = def; scope->classDef = def;
scope->check = nestedIn->check; scope->check = nestedIn->check;
scope->bodyStart = addtoken(tokenList, "{"); scope->bodyStart = addtoken(tokenList, "{");
mData->scopeAccessControl[scope] = (scopeType == Scope::ScopeType::eClass) ? AccessControl::Private : AccessControl::Public;
if (!children2.empty()) { if (!children2.empty()) {
tokenList->back()->scope(scope); tokenList->back()->scope(scope);
for (AstNodePtr astNode: children2) { for (AstNodePtr astNode: children2) {
if (astNode->nodeType == "VisibilityAttr") if (astNode->nodeType == "VisibilityAttr")
continue; continue;
if (astNode->nodeType == AccessSpecDecl) {
if (contains(astNode->mExtTokens, "private"))
mData->scopeAccessControl[scope] = AccessControl::Private;
else if (contains(astNode->mExtTokens, "protected"))
mData->scopeAccessControl[scope] = AccessControl::Protected;
else if (contains(astNode->mExtTokens, "public"))
mData->scopeAccessControl[scope] = AccessControl::Public;
continue;
}
astNode->createTokens(tokenList); astNode->createTokens(tokenList);
if (scopeType == Scope::ScopeType::eEnum) if (scopeType == Scope::ScopeType::eEnum)
astNode->addtoken(tokenList, ","); astNode->addtoken(tokenList, ",");
@ -571,6 +589,7 @@ Scope *clangimport::AstNode::createScope(TokenList *tokenList, Scope::ScopeType
} }
scope->bodyEnd = addtoken(tokenList, "}"); scope->bodyEnd = addtoken(tokenList, "}");
Token::createMutualLinks(const_cast<Token*>(scope->bodyStart), const_cast<Token*>(scope->bodyEnd)); Token::createMutualLinks(const_cast<Token*>(scope->bodyStart), const_cast<Token*>(scope->bodyEnd));
mData->scopeAccessControl.erase(scope);
return scope; return scope;
} }
@ -1140,10 +1159,10 @@ Token * clangimport::AstNode::createTokensCall(TokenList *tokenList)
void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList) void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
{ {
const bool prev = (std::find(mExtTokens.begin(), mExtTokens.end(), "prev") != mExtTokens.end()); const bool prev = contains(mExtTokens, "prev");
const bool hasBody = !children.empty() && children.back()->nodeType == CompoundStmt; const bool hasBody = !children.empty() && children.back()->nodeType == CompoundStmt;
const bool isStatic = (std::find(mExtTokens.begin(), mExtTokens.end(), "static") != mExtTokens.end()); const bool isStatic = contains(mExtTokens, "static");
const bool isInline = (std::find(mExtTokens.begin(), mExtTokens.end(), "inline") != mExtTokens.end()); const bool isInline = contains(mExtTokens, "inline");
const Token *startToken = nullptr; const Token *startToken = nullptr;
@ -1181,6 +1200,12 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
Function * const function = const_cast<Function*>(nameToken->function()); Function * const function = const_cast<Function*>(nameToken->function());
if (!prev) {
auto accessControl = mData->scopeAccessControl.find(tokenList->back()->scope());
if (accessControl != mData->scopeAccessControl.end())
function->access = accessControl->second;
}
Scope *scope = nullptr; Scope *scope = nullptr;
if (hasBody) { if (hasBody) {
symbolDatabase->scopeList.push_back(Scope(nullptr, nullptr, nestedIn)); symbolDatabase->scopeList.push_back(Scope(nullptr, nullptr, nestedIn));
@ -1241,7 +1266,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
bodyStart->link(bodyEnd); bodyStart->link(bodyEnd);
bodyEnd->link(bodyStart); bodyEnd->link(bodyStart);
} else { } else {
if (nodeType == CXXConstructorDecl && (std::find(mExtTokens.begin(), mExtTokens.end(), "default") != mExtTokens.end())) { if (nodeType == CXXConstructorDecl && contains(mExtTokens, "default")) {
addtoken(tokenList, "="); addtoken(tokenList, "=");
addtoken(tokenList, "default"); addtoken(tokenList, "default");
} }
@ -1252,7 +1277,7 @@ void clangimport::AstNode::createTokensFunctionDecl(TokenList *tokenList)
void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList) void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList)
{ {
bool isStruct = (std::find(mExtTokens.begin(), mExtTokens.end(), "struct") != mExtTokens.end()); bool isStruct = contains(mExtTokens, "struct");
Token * const classToken = addtoken(tokenList, isStruct ? "struct" : "class"); Token * const classToken = addtoken(tokenList, isStruct ? "struct" : "class");
std::string className; std::string className;
if (mExtTokens[mExtTokens.size() - 2] == (isStruct?"struct":"class")) if (mExtTokens[mExtTokens.size() - 2] == (isStruct?"struct":"class"))
@ -1279,7 +1304,8 @@ void clangimport::AstNode::createTokensForCXXRecord(TokenList *tokenList)
child->nodeType == CXXDestructorDecl || child->nodeType == CXXDestructorDecl ||
child->nodeType == CXXMethodDecl || child->nodeType == CXXMethodDecl ||
child->nodeType == FieldDecl || child->nodeType == FieldDecl ||
child->nodeType == VarDecl) child->nodeType == VarDecl ||
child->nodeType == AccessSpecDecl)
children2.push_back(child); children2.push_back(child);
} }
Scope *scope = createScope(tokenList, isStruct ? Scope::ScopeType::eStruct : Scope::ScopeType::eClass, children2, classToken); Scope *scope = createScope(tokenList, isStruct ? Scope::ScopeType::eStruct : Scope::ScopeType::eClass, children2, classToken);
@ -1298,7 +1324,7 @@ Token * clangimport::AstNode::createTokensVarDecl(TokenList *tokenList)
{ {
const std::string addr = mExtTokens.front(); const std::string addr = mExtTokens.front();
const Token *startToken = nullptr; const Token *startToken = nullptr;
if (std::find(mExtTokens.cbegin(), mExtTokens.cend(), "static") != mExtTokens.cend()) if (contains(mExtTokens, "static"))
startToken = addtoken(tokenList, "static"); startToken = addtoken(tokenList, "static");
int typeIndex = mExtTokens.size() - 1; int typeIndex = mExtTokens.size() - 1;
while (typeIndex > 1 && std::isalpha(mExtTokens[typeIndex][0])) while (typeIndex > 1 && std::isalpha(mExtTokens[typeIndex][0]))

View File

@ -101,6 +101,15 @@ def test_symbol_database_6():
def test_symbol_database_7(): def test_symbol_database_7():
check_symbol_database('struct S {int x;}; void f(struct S *s) {}') check_symbol_database('struct S {int x;}; void f(struct S *s) {}')
def test_symbol_database_class_access_1():
check_symbol_database('class Fred { void foo ( ) {} } ;')
def test_symbol_database_class_access_2():
check_symbol_database('class Fred { protected: void foo ( ) {} } ;')
def test_symbol_database_class_access_3():
check_symbol_database('class Fred { public: void foo ( ) {} } ;')
def test_symbol_database_operator(): def test_symbol_database_operator():
check_symbol_database('struct Fred { void operator=(int x); };') check_symbol_database('struct Fred { void operator=(int x); };')