#include "forwardanalyzer.h"
#include "astutils.h"
#include "settings.h"
#include "symboldatabase.h"

#include <functional>

/// Walks the token list forward from a starting point, asking a ForwardAnalyzer
/// how each token affects the tracked value and following control flow
/// (short-circuit operators, if/else, loops, break, labels) to decide how far
/// the value can be propagated.
struct ForwardTraversal {
    /// Outcome of visiting a token or AST sub-tree:
    ///  - Continue: keep going.
    ///  - Break:    stop propagating.
    ///  - Skip:     traverseTok handled the token itself; traverseRecursive does
    ///              not descend into its second operand.
    enum class Progress { Continue, Break, Skip };
    ValuePtr<ForwardAnalyzer> analyzer;
    const Settings* settings;

    /// Evaluates a condition with the analyzer.
    /// @return {checkThen, checkElse}. checkThen is true if any evaluated result
    ///         is non-zero and checkElse is true if any is zero. An empty result
    ///         yields {false, false}.
    std::pair<bool, bool> evalCond(const Token* tok) {
        std::vector<int> result = analyzer->evaluate(tok);
        bool checkThen = std::any_of(result.begin(), result.end(), [](int x) {
            return x;
        });
        bool checkElse = std::any_of(result.begin(), result.end(), [](int x) {
            return !x;
        });
        return std::make_pair(checkThen, checkElse);
    }

    /// Visits a single token, handling the constructs that need special
    /// control-flow treatment; every other token is passed to @p f.
    /// @param tok             token to visit
    /// @param f               callback applied to ordinary tokens
    /// @param traverseUnknown forwarded to traverseConditional
    /// @param out             if non-null, may receive a token that the caller
    ///                        should continue from (end of a skipped region)
    template<class T, REQUIRES("T must be a Token class", std::is_convertible<T*, const Token*>)>
    Progress traverseTok(T* tok, const std::function<Progress(T*)>& f, bool traverseUnknown, T** out = nullptr) {
        if (Token::Match(tok, "asm|goto|continue|setjmp|longjmp")) {
            // Control flow this traversal cannot follow.
            return Progress::Break;
        } else if (Token::Match(tok, "return|throw") || isEscapeFunction(tok, &settings->library)) {
            // Still visit the operands (e.g. the returned expression) before stopping.
            traverseRecursive(tok->astOperand1(), f, traverseUnknown);
            traverseRecursive(tok->astOperand2(), f, traverseUnknown);
            return Progress::Break;
        } else if (isUnevaluated(tok)) {
            // Operand of sizeof/decltype is not evaluated: jump past it.
            if (out)
                *out = tok->link();
            return Progress::Skip;
        } else if (Token::Match(tok, "?|&&|%oror%")) {
            // Short-circuiting operators: only traverse reachable operands.
            if (traverseConditional(tok, f, traverseUnknown) == Progress::Break)
                return Progress::Break;
            if (out)
                *out = nextAfterAstRightmostLeaf(tok);
            return Progress::Skip;
        } else if (T* lambdaEndToken = findLambdaEndToken(tok)) {
            // Skip lambdas; stop if the lambda body modifies the value.
            if (checkScope(lambdaEndToken).isModified())
                return Progress::Break;
            if (out)
                *out = lambdaEndToken;
        } else {
            if (f(tok) == Progress::Break)
                return Progress::Break;
        }
        return Progress::Continue;
    }

    /// Visits the AST rooted at @p tok in order: first operand, then the token
    /// itself, then the second operand (unless the token reported Skip).
    template<class T, REQUIRES("T must be a Token class", std::is_convertible<T*, const Token*>)>
    Progress traverseRecursive(T* tok, const std::function<Progress(T*)>& f, bool traverseUnknown) {
        if (!tok)
            return Progress::Continue;
        if (tok->astOperand1() && traverseRecursive(tok->astOperand1(), f, traverseUnknown) == Progress::Break)
            return Progress::Break;
        Progress p = traverseTok(tok, f, traverseUnknown);
        if (p == Progress::Break)
            return Progress::Break;
        if (p == Progress::Continue && traverseRecursive(tok->astOperand2(), f, traverseUnknown) == Progress::Break)
            return Progress::Break;
        return Progress::Continue;
    }

    /// Traverses a `?:`, `&&` or `||` expression. The condition is always
    /// traversed. Operands that evalCond shows are never evaluated are skipped.
    /// If the condition cannot be decided, both sides are traversed. The
    /// exception is when traverseUnknown is false and the analyzer reports
    /// isConditional(); then the traversal stops.
    template<class T, class F, REQUIRES("T must be a Token class", std::is_convertible<T*, const Token*>)>
    Progress traverseConditional(T* tok, const F& f, bool traverseUnknown) {
        if (Token::Match(tok, "?|&&|%oror%")) {
            T* condTok = tok->astOperand1();
            if (traverseRecursive(condTok, f, traverseUnknown) == Progress::Break)
                return Progress::Break;
            T* childTok = tok->astOperand2();
            bool checkThen, checkElse;
            std::tie(checkThen, checkElse) = evalCond(condTok);
            if (!checkThen && !checkElse) {
                // Stop if the value is conditional
                if (!traverseUnknown && analyzer->isConditional())
                    return Progress::Break;
                checkThen = true;
                checkElse = true;
            }
            if (Token::simpleMatch(childTok, ":")) {
                if (checkThen && traverseRecursive(childTok->astOperand1(), f, traverseUnknown) == Progress::Break)
                    return Progress::Break;
                if (checkElse && traverseRecursive(childTok->astOperand2(), f, traverseUnknown) == Progress::Break)
                    return Progress::Break;
            } else {
                // The RHS of && runs only if the LHS can be true; of || only if it can be false.
                if (!checkThen && Token::simpleMatch(tok, "&&"))
                    return Progress::Continue;
                if (!checkElse && Token::simpleMatch(tok, "||"))
                    return Progress::Continue;
                if (traverseRecursive(childTok, f, traverseUnknown) == Progress::Break)
                    return Progress::Break;
            }
        }
        return Progress::Continue;
    }

    /// Applies the analyzer's action for @p tok (if any); stops on an invalid action.
    Progress update(Token* tok) {
        ForwardAnalyzer::Action action = analyzer->analyze(tok);
        if (!action.isNone())
            analyzer->update(tok, action);
        if (action.isInvalid())
            return Progress::Break;
        return Progress::Continue;
    }

    /// Updates a single token via traverseTok (unknown conditionals are not traversed).
    Progress updateTok(Token* tok, Token** out = nullptr) {
        std::function<Progress(Token*)> f = [this](Token* tok2) {
            return update(tok2);
        };
        return traverseTok(tok, f, false, out);
    }

    /// Updates every token of the AST rooted at @p tok.
    Progress updateRecursive(Token* tok) {
        std::function<Progress(Token*)> f = [this](Token* tok2) {
            return update(tok2);
        };
        return traverseRecursive(tok, f, false);
    }

    /// @return the first token in [start, end) whose analyzer action satisfies
    ///         @p pred, or nullptr if none does.
    template <class T>
    T* findRange(T* start, const Token* end, std::function<bool(ForwardAnalyzer::Action)> pred) {
        for (T* tok = start; tok && tok != end; tok = tok->next()) {
            ForwardAnalyzer::Action action = analyzer->analyze(tok);
            if (pred(action))
                return tok;
        }
        return nullptr;
    }

    /// Analyzes (without updating) the AST rooted at @p start, traversing
    /// undecidable conditionals too.
    /// @return the first modifying/inconclusive action found, otherwise the
    ///         action of the last token visited (None if nothing was visited).
    ForwardAnalyzer::Action analyzeRecursive(const Token* start) {
        ForwardAnalyzer::Action result = ForwardAnalyzer::Action::None;
        std::function<Progress(const Token *)> f = [&](const Token* tok) {
            result = analyzer->analyze(tok);
            if (result.isModified() || result.isInconclusive())
                return Progress::Break;
            return Progress::Continue;
        };
        traverseRecursive(start, f, true);
        return result;
    }

    /// Analyzes (without updating) the tokens in [start, end).
    /// @return the first modifying/inconclusive action, otherwise the action of
    ///         the last token (None for an empty range).
    ForwardAnalyzer::Action analyzeRange(const Token* start, const Token* end) {
        ForwardAnalyzer::Action result = ForwardAnalyzer::Action::None;
        for (const Token* tok = start; tok && tok != end; tok = tok->next()) {
            ForwardAnalyzer::Action action = analyzer->analyze(tok);
            if (action.isModified() || action.isInconclusive())
                return action;
            result = action;
        }
        return result;
    }

    /// If the analyzer's updateScope() agrees, propagates through the block
    /// ending at @p endBlock using a copy of this traversal. This traversal is
    /// left positioned where it was.
    void forkScope(Token* endBlock, bool isModified = false) {
        if (analyzer->updateScope(endBlock, isModified)) {
            ForwardTraversal ft = *this;
            ft.updateRange(endBlock->link(), endBlock);
        }
    }

    /// @return true if the block ending at @p endBlock contains a `goto` token.
    static bool hasGoto(const Token* endBlock) {
        return Token::findsimplematch(endBlock->link(), "goto", endBlock) != nullptr;
    }

    /// @return whether the block ending at @p endBlock escapes (per isReturnScope).
    /// When isReturnScope returns false but reports a function token via ftok,
    /// the result is @p unknown (presumably an unclassified call — confirm
    /// against isReturnScope).
    bool isEscapeScope(const Token* endBlock, bool unknown = false) {
        const Token* ftok = nullptr;
        bool r = isReturnScope(endBlock, &settings->library, &ftok);
        if (!r && ftok)
            return unknown;
        return r;
    }

    // NOTE: not referenced anywhere in this file.
    enum class Status {
        None,
        Escaped,
        Modified,
        Inconclusive,
    };

    /// Analyzes (without updating) the block ending at @p endBlock.
    ForwardAnalyzer::Action analyzeScope(const Token* endBlock) {
        return analyzeRange(endBlock->link(), endBlock);
    }

    /// Analyzes the block ending at @p endBlock and also forks a traversal
    /// into it (see forkScope).
    ForwardAnalyzer::Action checkScope(Token* endBlock) {
        ForwardAnalyzer::Action a = analyzeScope(endBlock);
        forkScope(endBlock, a.isModified());
        return a;
    }

    /// Const overload: only analyzes the block; no traversal is forked.
    ForwardAnalyzer::Action checkScope(const Token* endBlock) {
        ForwardAnalyzer::Action a = analyzeScope(endBlock);
        return a;
    }

    /// Handles a loop whose body ends at @p endBlock.
    /// If the body, init or step modifies the value (or is inconclusive), the
    /// analyzer is lowered first; propagation stops if it cannot be lowered.
    /// Then the condition is updated and the body is forked into.
    /// If the body modifies the value, propagation continues past the loop only
    /// when the statement after the first write is `break;`. Otherwise the step
    /// expression is updated.
    Progress updateLoop(Token* endBlock, Token* condTok, Token* initTok = nullptr, Token* stepTok = nullptr) {
        ForwardAnalyzer::Action bodyAnalysis = analyzeScope(endBlock);
        ForwardAnalyzer::Action allAnalysis = bodyAnalysis;
        if (initTok)
            allAnalysis |= analyzeRecursive(initTok);
        if (stepTok)
            allAnalysis |= analyzeRecursive(stepTok);
        if (allAnalysis.isInconclusive()) {
            if (!analyzer->lowerToInconclusive())
                return Progress::Break;
        } else if (allAnalysis.isModified()) {
            if (!analyzer->lowerToPossible())
                return Progress::Break;
        }
        // Traverse condition after lowering
        if (condTok && updateRecursive(condTok) == Progress::Break)
            return Progress::Break;
        forkScope(endBlock, allAnalysis.isModified());
        if (bodyAnalysis.isModified()) {
            Token* writeTok = findRange(endBlock->link(), endBlock, std::mem_fn(&ForwardAnalyzer::Action::isModified));
            const Token* nextStatement = Token::findmatch(writeTok, ";|}", endBlock);
            if (!Token::Match(nextStatement, ";|} break ;"))
                return Progress::Break;
        } else {
            if (stepTok && updateRecursive(stepTok) == Progress::Break)
                return Progress::Break;
        }
        // TODO: Should we traverse the body?
        // updateRange(endBlock->link(), endBlock);
        return Progress::Continue;
    }

    /// Propagates the value through the tokens in [start, end), following
    /// control flow. Returns Progress::Break as soon as propagation must stop.
    Progress updateRange(Token* start, const Token* end) {
        for (Token* tok = start; tok && tok != end; tok = tok->next()) {
            Token* next = nullptr;

            // Evaluate RHS of assignment before LHS
            if (Token* assignTok = assignExpr(tok)) {
                if (updateRecursive(assignTok->astOperand2()) == Progress::Break)
                    return Progress::Break;
                if (updateRecursive(assignTok->astOperand1()) == Progress::Break)
                    return Progress::Break;
                if (update(assignTok) == Progress::Break)
                    return Progress::Break;
                tok = nextAfterAstRightmostLeaf(assignTok);
                if (!tok)
                    return Progress::Break;
            } else if (Token::simpleMatch(tok, "break")) {
                // Jump to the end of the enclosing loop/switch; the value is
                // only possible there.
                const Scope* scope = findBreakScope(tok->scope());
                if (!scope)
                    return Progress::Break;
                tok = skipTo(tok, scope->bodyEnd, end);
                if (!analyzer->lowerToPossible())
                    return Progress::Break;
                // TODO: Don't break, instead move to the outer scope
                if (!tok)
                    return Progress::Break;
            } else if (Token::Match(tok, "%name% :") || Token::simpleMatch(tok, "case")) {
                // Labels and case entries may be reached from other paths.
                if (!analyzer->lowerToPossible())
                    return Progress::Break;
            } else if (Token::simpleMatch(tok, "}") && Token::Match(tok->link()->previous(), ")|else {")) {
                // Closing brace of a `) {` / `else {` block reached from inside
                // it. Blocks entered through their header are skipped by the
                // branches below, so the traversal presumably started in here.
                const bool inElse = Token::simpleMatch(tok->link()->previous(), "else {");
                const Token* condTok = getCondTokFromEnd(tok);
                if (!condTok)
                    return Progress::Break;
                if (!condTok->hasKnownIntValue()) {
                    if (!analyzer->lowerToPossible())
                        return Progress::Break;
                } else if (condTok->values().front().intvalue == !inElse) {
                    // NOTE(review): this compares the known value against 0/1 only,
                    // so a truthy value other than 1 never matches. It also stops
                    // when the branch being left *was* taken. Confirm the intended
                    // direction (loops ending here interact with it).
                    return Progress::Break;
                }
                analyzer->assume(condTok, !inElse);
                if (Token::simpleMatch(tok, "} else {"))
                    tok = tok->linkAt(2);
            } else if (Token::Match(tok, "if|while|for (") && Token::simpleMatch(tok->next()->link(), ") {")) {
                Token* endCond = tok->next()->link();
                Token* endBlock = endCond->next()->link();
                Token* condTok = getCondTok(tok);
                Token* initTok = getInitTok(tok);
                if (!condTok)
                    return Progress::Break;
                if (initTok && updateRecursive(initTok) == Progress::Break)
                    return Progress::Break;
                if (Token::Match(tok, "for|while (")) {
                    Token* stepTok = getStepTok(tok);
                    if (updateLoop(endBlock, condTok, initTok, stepTok) == Progress::Break)
                        return Progress::Break;
                    tok = endBlock;
                } else {
                    // Traverse condition
                    if (updateRecursive(condTok) == Progress::Break)
                        return Progress::Break;
                    // Check if condition is true or false
                    bool checkThen, checkElse;
                    std::tie(checkThen, checkElse) = evalCond(condTok);
                    ForwardAnalyzer::Action thenAction = ForwardAnalyzer::Action::None;
                    ForwardAnalyzer::Action elseAction = ForwardAnalyzer::Action::None;
                    bool hasElse = Token::simpleMatch(endBlock, "} else {");
                    bool bail = false;

                    // Traverse then block
                    bool returnThen = isEscapeScope(endBlock, true);
                    bool returnElse = false;
                    if (checkThen) {
                        if (updateRange(endCond->next(), endBlock) == Progress::Break)
                            return Progress::Break;
                    } else if (!checkElse) {
                        thenAction = checkScope(endBlock);
                        if (hasGoto(endBlock))
                            bail = true;
                    }
                    // Traverse else block
                    if (hasElse) {
                        Token* endElse = endBlock->linkAt(2);
                        returnElse = isEscapeScope(endElse, true);
                        if (checkElse) {
                            Progress result = updateRange(endBlock->tokAt(2), endElse);
                            if (result == Progress::Break)
                                return Progress::Break;
                        } else if (!checkThen) {
                            elseAction = checkScope(endElse);
                            // Look for goto in the else block itself; the then
                            // block was already checked above.
                            if (hasGoto(endElse))
                                bail = true;
                        }
                        tok = endElse;
                    } else {
                        tok = endBlock;
                    }
                    if (bail)
                        return Progress::Break;
                    if (returnThen && returnElse)
                        return Progress::Break;
                    else if (thenAction.isModified() && elseAction.isModified())
                        return Progress::Break;
                    else if ((returnThen || returnElse) && (thenAction.isModified() || elseAction.isModified()))
                        return Progress::Break;
                    // Conditional return
                    if (returnThen && !hasElse) {
                        if (checkThen) {
                            return Progress::Break;
                        } else {
                            if (analyzer->isConditional())
                                return Progress::Break;
                            analyzer->assume(condTok, false);
                        }
                    }
                    if (thenAction.isInconclusive() || elseAction.isInconclusive()) {
                        if (!analyzer->lowerToInconclusive())
                            return Progress::Break;
                    } else if (thenAction.isModified() || elseAction.isModified()) {
                        if (!hasElse && analyzer->isConditional())
                            return Progress::Break;
                        if (!analyzer->lowerToPossible())
                            return Progress::Break;
                        analyzer->assume(condTok, elseAction.isModified());
                    }
                }
            } else if (Token::simpleMatch(tok, "} else {")) {
                tok = tok->linkAt(2);
            } else if (Token::simpleMatch(tok, "do {")) {
                Token* endBlock = tok->next()->link();
                if (updateLoop(endBlock, nullptr) == Progress::Break)
                    return Progress::Break;
                tok = endBlock;
            } else if (Token::Match(tok, "assert|ASSERT (")) {
                // An assertion that can fail stops propagation; otherwise the
                // asserted condition is assumed to hold afterwards.
                const Token* condTok = tok->next()->astOperand2();
                bool checkThen, checkElse;
                std::tie(checkThen, checkElse) = evalCond(condTok);
                if (checkElse)
                    return Progress::Break;
                if (!checkThen)
                    analyzer->assume(condTok, true);
            } else if (Token::simpleMatch(tok, "switch (")) {
                // Only the switch expression is updated; switch bodies are not followed.
                if (updateRecursive(tok->next()->astOperand2()) == Progress::Break)
                    return Progress::Break;
                return Progress::Break;
            } else {
                if (updateTok(tok, &next) == Progress::Break)
                    return Progress::Break;
                if (next)
                    tok = next;
            }
            // Prevent infinite recursion
            if (tok->next() == start)
                break;
        }
        return Progress::Continue;
    }

    /// @return true if @p tok is the `(` following `sizeof` or `decltype`.
    static bool isUnevaluated(const Token* tok) {
        if (Token::Match(tok->previous(), "sizeof|decltype ("))
            return true;
        return false;
    }

    /// Walks up the AST while @p tok is on the left-hand side of its parent.
    /// @return the first assignment operator reached that way, or nullptr.
    static Token* assignExpr(Token* tok) {
        while (tok->astParent() && astIsLHS(tok)) {
            if (Token::Match(tok->astParent(), "%assign%"))
                return tok->astParent();
            tok = tok->astParent();
        }
        return nullptr;
    }

    /// @return the innermost enclosing while/for/switch scope (the target of
    ///         `break`), or nullptr if there is none.
    static const Scope* findBreakScope(const Scope* scope) {
        while (scope && scope->type != Scope::eWhile && scope->type != Scope::eFor && scope->type != Scope::eSwitch)
            scope = scope->nestedIn;
        return scope;
    }

    /// Moves forward from @p tok to @p dest.
    /// @return nullptr if @p dest lies after @p end (when given) or is not
    ///         strictly after @p tok.
    static Token* skipTo(Token* tok, const Token* dest, const Token* end = nullptr) {
        if (end && dest->index() > end->index())
            return nullptr;
        const int i = dest->index() - tok->index();
        if (i > 0)
            return tok->tokAt(i);
        return nullptr;
    }

    /// @return true if @p tok sits in a conditionally evaluated position: the
    ///         RHS of the nearest enclosing `&&`/`||`, or under a `:` of `?:`.
    /// NOTE: not referenced in this file (the analyzer's isConditional() is used instead).
    static bool isConditional(const Token* tok) {
        const Token* parent = tok->astParent();
        while (parent && !Token::Match(parent, "%oror%|&&|:")) {
            tok = parent;
            parent = parent->astParent();
        }
        return parent && (parent->str() == ":" || parent->astOperand2() == tok);
    }

    /// For a `keyword (init; ...)` header, returns the init expression: the
    /// first operand of the `;` that is the `(`'s second AST operand.
    /// Returns nullptr when there is no such `;` or the init part is itself a `;`.
    static Token* getInitTok(Token* tok) {
        if (!tok)
            return nullptr;
        if (Token::Match(tok, "%name% ("))
            return getInitTok(tok->next());
        if (!Token::simpleMatch(tok, "("))
            return nullptr;
        if (!Token::simpleMatch(tok->astOperand2(), ";"))
            return nullptr;
        if (Token::simpleMatch(tok->astOperand2()->astOperand1(), ";"))
            return nullptr;
        return tok->astOperand2()->astOperand1();
    }

    /// For a `keyword (init; cond; step)` header, returns the step expression:
    /// the second operand of the nested `;`. Returns nullptr otherwise.
    static Token* getStepTok(Token* tok) {
        if (!tok)
            return nullptr;
        if (Token::Match(tok, "%name% ("))
            return getStepTok(tok->next());
        if (!Token::simpleMatch(tok, "("))
            return nullptr;
        if (!Token::simpleMatch(tok->astOperand2(), ";"))
            return nullptr;
        if (!Token::simpleMatch(tok->astOperand2()->astOperand2(), ";"))
            return nullptr;
        return tok->astOperand2()->astOperand2()->astOperand2();
    }

};

/// Propagates the value described by @p fa forward through the tokens in
/// [start, end), using @p settings for library information (escape functions,
/// return scopes). The traversal's result is not reported to the caller.
void valueFlowGenericForward(Token* start, const Token* end, const ValuePtr<ForwardAnalyzer>& fa, const Settings* settings)
{
    ForwardTraversal{fa, settings}.updateRange(start, end);
}