#include "routing.hpp"

#include "bmcweb_config.h"

#include "error_messages.hpp"
#include "http_request.hpp"
#include "http_response.hpp"
#include "logging.hpp"
#include "utils/dbus_utils.hpp"

#include <boost/beast/ssl/ssl_stream.hpp>
#include <sdbusplus/unpack_properties.hpp>

namespace crow
{

/**
 * Decides whether the session attached to @p req may invoke @p rule, using
 * the user properties read from D-Bus.
 *
 * @param req         Request whose session is updated: isConfigureSelfOnly is
 *                    overwritten from "UserPasswordExpired", and on success
 *                    req.userRole receives the user's role.
 * @param asyncResp   Response that receives an error status when access is
 *                    refused or the properties are unusable.
 * @param rule        Rule whose privilege requirements are checked.
 * @param userInfoMap D-Bus properties describing the user ("UserPrivilege",
 *                    "RemoteUser", "UserPasswordExpired").
 * @return true if the request may proceed. On false, asyncResp holds 500 for
 *         missing or malformed properties, or 403 for insufficient privilege
 *         (403 also carries a PasswordChangeRequired message when the
 *         password has expired).
 */
bool Router::isUserPrivileged(
    Request& req, const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
    BaseRule& rule, const dbus::utility::DBusPropertiesMap& userInfoMap)
{
    const std::string* rolePtr = nullptr;
    const bool* isRemotePtr = nullptr;
    const bool* isExpiredPtr = nullptr;

    if (!sdbusplus::unpackPropertiesNoThrow(
            redfish::dbus_utils::UnpackErrorPrinter(), userInfoMap,
            "UserPrivilege", rolePtr, "RemoteUser", isRemotePtr,
            "UserPasswordExpired", isExpiredPtr))
    {
        asyncResp->res.result(
            boost::beast::http::status::internal_server_error);
        return false;
    }

    // A missing role leaves userRole empty, which is what gets mapped to
    // privileges below.
    std::string userRole;
    if (rolePtr != nullptr)
    {
        userRole = *rolePtr;
        BMCWEB_LOG_DEBUG << "userName = " << req.session->username
                         << " userRole = " << *rolePtr;
    }

    if (isRemotePtr == nullptr)
    {
        BMCWEB_LOG_ERROR << "RemoteUser property missing or wrong type";
        asyncResp->res.result(
            boost::beast::http::status::internal_server_error);
        return false;
    }

    // Local users must report password expiry. Remote users may omit it, in
    // which case the password is treated as not expired.
    if (isExpiredPtr == nullptr && !*isRemotePtr)
    {
        BMCWEB_LOG_ERROR << "UserPasswordExpired property is expected for"
                            " local user but is missing or wrong type";
        asyncResp->res.result(
            boost::beast::http::status::internal_server_error);
        return false;
    }
    const bool passwordExpired = (isExpiredPtr != nullptr) && *isExpiredPtr;

    redfish::Privileges userPrivileges = redfish::getUserPrivileges(userRole);

    // D-Bus is authoritative here: this overrides anything set by
    // pamAuthenticateUser or by an earlier request on the same session.
    req.session->isConfigureSelfOnly = passwordExpired;
    if (passwordExpired)
    {
        // An expired password leaves only the ConfigureSelf privilege.
        userPrivileges =
            userPrivileges.intersection(redfish::Privileges{"ConfigureSelf"});
        BMCWEB_LOG_DEBUG << "Operation limited to ConfigureSelf";
    }

    if (rule.checkPrivileges(userPrivileges))
    {
        req.userRole = userRole;
        return true;
    }

    asyncResp->res.result(boost::beast::http::status::forbidden);
    if (passwordExpired)
    {
        redfish::messages::passwordChangeRequired(
            asyncResp->res, crow::utility::urlFromPieces(
                                "redfish", "v1", "AccountService",
                                "Accounts", req.session->username));
    }
    return false;
}

/**
 * Compacts the subtree rooted at @p node in place.
 *
 * Parameter children are optimized recursively but never merged. If every
 * literal child satisfies Node::isSimpleNode() (presumably: no rule and no
 * parameter children; confirm in routing.hpp), each child is spliced out. Its
 * own children are re-attached to @p node under the concatenated fragment
 * key, and @p node is re-optimized. Otherwise each literal child is optimized
 * in turn.
 *
 * No nodes are added to `nodes`, so the Node pointers taken here stay valid.
 */
void Trie::optimizeNode(Node* node)
{
    for (size_t paramChildIdx : node->paramChildrens)
    {
        if (paramChildIdx != 0U)
        {
            optimizeNode(&nodes[paramChildIdx]);
        }
    }

    if (node->children.empty())
    {
        return;
    }

    bool allChildrenSimple = true;
    for (const Node::ChildMap::value_type& entry : node->children)
    {
        if (!nodes[entry.second].isSimpleNode())
        {
            allChildrenSimple = false;
            break;
        }
    }

    if (!allChildrenSimple)
    {
        for (const Node::ChildMap::value_type& entry : node->children)
        {
            optimizeNode(&nodes[entry.second]);
        }
        return;
    }

    // Lift every grandchild one level up, prefixing its fragment with the
    // fragment of the simple child being removed.
    Node::ChildMap collapsed;
    for (const Node::ChildMap::value_type& entry : node->children)
    {
        const Node& simpleChild = nodes[entry.second];
        for (const Node::ChildMap::value_type& grandchild :
             simpleChild.children)
        {
            collapsed[entry.first + grandchild.first] = grandchild.second;
        }
    }
    node->children = std::move(collapsed);

    // The lifted children may themselves be mergeable.
    optimizeNode(node);
}

/**
 * Appends to @p routeIndexes the rule indexes of every route whose path
 * extends the prefix @p reqUrl.
 *
 * Only literal children are followed, so routes reached through parameter
 * children are not reported. Children reached through a bare "/" fragment
 * are skipped. These appear to be the trailing-slash twins that
 * internalAddRuleObject registers with the same rule index (verify). A rule
 * ending exactly at the prefix is not reported; only rules below it are.
 *
 * Because optimizeNode merges runs of single-character nodes into longer
 * fragments, the prefix may end in the middle of a fragment. That fragment
 * is accepted when the rest of the prefix is a prefix of it. Comparing the
 * whole fragment against the shorter remainder would always fail.
 *
 * @param reqUrl       URL prefix to search under.
 * @param routeIndexes Output; matching rule indexes are appended.
 * @param node         Subtree root; nullptr means the trie's head.
 * @param pos          Offset in reqUrl already consumed by ancestors.
 */
void Trie::findRouteIndexes(const std::string& reqUrl,
                            std::vector<unsigned>& routeIndexes,
                            const Node* node, unsigned pos) const
{
    if (node == nullptr)
    {
        node = head();
    }
    for (const Node::ChildMap::value_type& kv : node->children)
    {
        const std::string& fragment = kv.first;
        const Node* child = &nodes[kv.second];
        const unsigned childPos = static_cast<unsigned>(pos + fragment.size());

        if (pos < reqUrl.size())
        {
            const size_t remaining = reqUrl.size() - pos;
            if (fragment.size() <= remaining)
            {
                // The fragment lies entirely inside the prefix, so it must
                // match exactly.
                if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
                {
                    findRouteIndexes(reqUrl, routeIndexes, child, childPos);
                }
                continue;
            }
            // The prefix ends inside this fragment. Everything at or below
            // child extends the prefix only if the fragment starts with the
            // rest of the prefix.
            if (fragment.compare(0, remaining, reqUrl, pos, remaining) != 0)
            {
                continue;
            }
        }

        // child is past the end of the prefix: report its rule and descend.
        if (child->ruleIndex != 0 && fragment != "/")
        {
            routeIndexes.push_back(child->ruleIndex);
        }
        findRouteIndexes(reqUrl, routeIndexes, child, childPos);
    }
}

/**
 * Recursively matches reqUrl[pos:] against the subtree rooted at @p node.
 *
 * At each node the int, uint, double, string and path parameter children are
 * tried, then the literal children. When several complete matches exist, the
 * one with the smallest non-zero rule index wins.
 *
 * @param reqUrl URL path to match. It does NOT need to be NUL-terminated.
 * @param node   Subtree root; nullptr means the trie's head.
 * @param pos    Offset into reqUrl where matching continues.
 * @param params Parameters captured so far; nullptr starts with none. Values
 *               are pushed before descending and popped afterwards, so the
 *               vectors are left as they were on entry.
 * @return {ruleIndex, params}; ruleIndex is 0 when nothing matched.
 *
 * Numeric parameters are parsed from a NUL-terminated copy of the unmatched
 * remainder. strtoll/strtoull/strtod scan until a non-numeric character, and
 * a string_view need not be terminated at reqUrl.size(). Parsing from
 * reqUrl.data() directly could therefore read past the view and produce an
 * end offset beyond reqUrl.size(). That would make reqUrl[pos] undefined and
 * reqUrl.compare() throw std::out_of_range.
 */
std::pair<unsigned int, RoutingParams> Trie::find(std::string_view reqUrl,
                                                  const Node* node, size_t pos,
                                                  RoutingParams* params) const
{
    RoutingParams empty;
    if (params == nullptr)
    {
        params = &empty;
    }

    unsigned found{};
    RoutingParams matchParams;

    if (node == nullptr)
    {
        node = head();
    }
    if (pos == reqUrl.size())
    {
        return {node->ruleIndex, *params};
    }

    // Keep the match with the lowest non-zero rule index.
    auto updateFound = [&found,
                        &matchParams](std::pair<unsigned, RoutingParams>& ret) {
        if (ret.first != 0U && (found == 0U || found > ret.first))
        {
            found = ret.first;
            matchParams = std::move(ret.second);
        }
    };

    if (node->paramChildrens[static_cast<size_t>(ParamType::INT)] != 0U)
    {
        char c = reqUrl[pos];
        if ((c >= '0' && c <= '9') || c == '+' || c == '-')
        {
            // Bounded, NUL-terminated copy; see the function comment.
            const std::string tail(reqUrl.substr(pos));
            const char* begin = tail.c_str();
            char* eptr = nullptr;
            errno = 0;
            long long int value = std::strtoll(begin, &eptr, 10);
            if (errno != ERANGE && eptr != begin)
            {
                params->intParams.push_back(value);
                std::pair<unsigned, RoutingParams> ret =
                    find(reqUrl,
                         &nodes[node->paramChildrens[static_cast<size_t>(
                             ParamType::INT)]],
                         pos + static_cast<size_t>(eptr - begin), params);
                updateFound(ret);
                params->intParams.pop_back();
            }
        }
    }

    if (node->paramChildrens[static_cast<size_t>(ParamType::UINT)] != 0U)
    {
        char c = reqUrl[pos];
        if ((c >= '0' && c <= '9') || c == '+')
        {
            const std::string tail(reqUrl.substr(pos));
            const char* begin = tail.c_str();
            char* eptr = nullptr;
            errno = 0;
            unsigned long long int value = std::strtoull(begin, &eptr, 10);
            if (errno != ERANGE && eptr != begin)
            {
                params->uintParams.push_back(value);
                std::pair<unsigned, RoutingParams> ret =
                    find(reqUrl,
                         &nodes[node->paramChildrens[static_cast<size_t>(
                             ParamType::UINT)]],
                         pos + static_cast<size_t>(eptr - begin), params);
                updateFound(ret);
                params->uintParams.pop_back();
            }
        }
    }

    if (node->paramChildrens[static_cast<size_t>(ParamType::DOUBLE)] != 0U)
    {
        char c = reqUrl[pos];
        if ((c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.')
        {
            const std::string tail(reqUrl.substr(pos));
            const char* begin = tail.c_str();
            char* eptr = nullptr;
            errno = 0;
            double value = std::strtod(begin, &eptr);
            if (errno != ERANGE && eptr != begin)
            {
                params->doubleParams.push_back(value);
                std::pair<unsigned, RoutingParams> ret =
                    find(reqUrl,
                         &nodes[node->paramChildrens[static_cast<size_t>(
                             ParamType::DOUBLE)]],
                         pos + static_cast<size_t>(eptr - begin), params);
                updateFound(ret);
                params->doubleParams.pop_back();
            }
        }
    }

    if (node->paramChildrens[static_cast<size_t>(ParamType::STRING)] != 0U)
    {
        // A string parameter consumes one non-empty path segment (up to the
        // next '/').
        size_t epos = pos;
        for (; epos < reqUrl.size(); epos++)
        {
            if (reqUrl[epos] == '/')
            {
                break;
            }
        }

        if (epos != pos)
        {
            params->stringParams.emplace_back(reqUrl.substr(pos, epos - pos));
            std::pair<unsigned, RoutingParams> ret =
                find(reqUrl,
                     &nodes[node->paramChildrens[static_cast<size_t>(
                         ParamType::STRING)]],
                     epos, params);
            updateFound(ret);
            params->stringParams.pop_back();
        }
    }

    if (node->paramChildrens[static_cast<size_t>(ParamType::PATH)] != 0U)
    {
        // A path parameter consumes the whole non-empty remainder.
        size_t epos = reqUrl.size();

        if (epos != pos)
        {
            params->stringParams.emplace_back(reqUrl.substr(pos, epos - pos));
            std::pair<unsigned, RoutingParams> ret =
                find(reqUrl,
                     &nodes[node->paramChildrens[static_cast<size_t>(
                         ParamType::PATH)]],
                     epos, params);
            updateFound(ret);
            params->stringParams.pop_back();
        }
    }

    for (const Node::ChildMap::value_type& kv : node->children)
    {
        const std::string& fragment = kv.first;
        const Node* child = &nodes[kv.second];

        if (reqUrl.compare(pos, fragment.size(), fragment) == 0)
        {
            std::pair<unsigned, RoutingParams> ret =
                find(reqUrl, child, pos + fragment.size(), params);
            updateFound(ret);
        }
    }

    return {found, matchParams};
}

/**
 * Inserts @p url into the trie and attaches @p ruleIndex to its final node.
 *
 * Literal characters become one-character child nodes. The parameter tokens
 * <int>, <uint>, <float>/<double>, <str>/<string> and <path> become
 * parameter children.
 *
 * @throws std::runtime_error if @p url contains a '<' that does not start one
 *         of the tokens above. Previously such a '<' was never consumed and
 *         the loop spun forever.
 * @throws std::runtime_error if a handler is already registered for the same
 *         path shape.
 *
 * newNode() may grow `nodes`, so nodes[idx] is re-indexed after every call
 * instead of holding a reference across it.
 */
void Trie::add(const std::string& url, unsigned int ruleIndex)
{
    size_t idx = 0;

    for (unsigned i = 0; i < url.size(); i++)
    {
        char c = url[i];
        if (c == '<')
        {
            const static std::array<std::pair<ParamType, std::string>, 7>
                paramTraits = {{
                    {ParamType::INT, "<int>"},
                    {ParamType::UINT, "<uint>"},
                    {ParamType::DOUBLE, "<float>"},
                    {ParamType::DOUBLE, "<double>"},
                    {ParamType::STRING, "<str>"},
                    {ParamType::STRING, "<string>"},
                    {ParamType::PATH, "<path>"},
                }};

            bool matchedParam = false;
            for (const std::pair<ParamType, std::string>& x : paramTraits)
            {
                if (url.compare(i, x.second.size(), x.second) == 0)
                {
                    size_t index = static_cast<size_t>(x.first);
                    if (nodes[idx].paramChildrens[index] == 0U)
                    {
                        unsigned newNodeIdx = newNode();
                        nodes[idx].paramChildrens[index] = newNodeIdx;
                    }
                    idx = nodes[idx].paramChildrens[index];
                    i += static_cast<unsigned>(x.second.size());
                    matchedParam = true;
                    break;
                }
            }

            if (!matchedParam)
            {
                // Without this, the i-- below plus the loop's i++ would
                // revisit the same '<' forever.
                throw std::runtime_error("unknown parameter type in route " +
                                         url);
            }

            // Step back one so the loop's i++ lands on the first character
            // after the parameter token.
            i--;
        }
        else
        {
            std::string piece(1, c);
            if (nodes[idx].children.count(piece) == 0U)
            {
                unsigned newNodeIdx = newNode();
                nodes[idx].children.emplace(piece, newNodeIdx);
            }
            idx = nodes[idx].children[piece];
        }
    }
    if (nodes[idx].ruleIndex != 0U)
    {
        throw std::runtime_error("handler already exists for " + url);
    }
    nodes[idx].ruleIndex = ruleIndex;
}

/**
 * Logs the subtree rooted at @p n at debug level, one log entry per child.
 * Each entry is indented by two spaces per @p level.
 *
 * Every BMCWEB_LOG_DEBUG statement is a separate log entry, so the indent and
 * the label are emitted in a single statement. The previous version logged
 * them separately, which lost the indentation.
 *
 * @param n     Subtree root.
 * @param level Depth of @p n, used only for indentation.
 */
void Trie::debugNodePrint(Node* n, size_t level)
{
    const std::string indent(2U * level, ' ');
    for (size_t i = 0; i < static_cast<size_t>(ParamType::MAX); i++)
    {
        if (n->paramChildrens[i] == 0U)
        {
            continue;
        }
        const char* paramName = "<ERROR>";
        switch (static_cast<ParamType>(i))
        {
            case ParamType::INT:
                paramName = "<int>";
                break;
            case ParamType::UINT:
                paramName = "<uint>";
                break;
            case ParamType::DOUBLE:
                paramName = "<float>";
                break;
            case ParamType::STRING:
                paramName = "<str>";
                break;
            case ParamType::PATH:
                paramName = "<path>";
                break;
            case ParamType::MAX:
                paramName = "<ERROR>";
                break;
        }
        BMCWEB_LOG_DEBUG << indent << paramName;
        debugNodePrint(&nodes[n->paramChildrens[i]], level + 1);
    }
    for (const Node::ChildMap::value_type& kv : n->children)
    {
        BMCWEB_LOG_DEBUG << indent << kv.first;
        debugNodePrint(&nodes[kv.second], level + 1);
    }
}

void Router::internalAddRuleObject(const std::string& rule,
                                   BaseRule* ruleObject)
{
    if (ruleObject == nullptr)
    {
        return;
    }
    for (size_t method = 0, methodBit = 1; method <= methodNotAllowedIndex;
         method++, methodBit <<= 1)
    {
        if ((ruleObject->methodsBitfield & methodBit) > 0U)
        {
            perMethods[method].rules.emplace_back(ruleObject);
            perMethods[method].trie.add(
                rule,
                static_cast<unsigned>(perMethods[method].rules.size() - 1U));
            // directory case:
            //   request to `/about' url matches `/about/' rule
            if (rule.size() > 2 && rule.back() == '/')
            {
                perMethods[method].trie.add(
                    rule.substr(0, rule.size() - 1),
                    static_cast<unsigned>(perMethods[method].rules.size() - 1));
            }
        }
    }
}

/**
 * Finalizes rule registration.
 *
 * For every non-null rule in allRules: replaces it with rule->upgrade() when
 * that returns a non-null object, calls rule->validate(), and indexes it into
 * the per-method tries. Afterwards every trie's validate() is called.
 */
void Router::validate()
{
    for (std::unique_ptr<BaseRule>& rule : allRules)
    {
        if (!rule)
        {
            continue;
        }
        if (std::unique_ptr<BaseRule> upgraded = rule->upgrade(); upgraded)
        {
            rule = std::move(upgraded);
        }
        rule->validate();
        internalAddRuleObject(rule->rule, rule.get());
    }

    for (PerMethod& perMethod : perMethods)
    {
        perMethod.trie.validate();
    }
}

/**
 * Looks up @p url in the trie of the per-method slot @p index.
 *
 * @return The matched rule and its captured parameters, or a FindRoute with
 *         a null rule when nothing matched or @p index is out of range (the
 *         latter is logged as critical).
 * @throws std::runtime_error if the trie yields a rule index outside the
 *         slot's rule table.
 */
Router::FindRoute Router::findRouteByIndex(std::string_view url,
                                           size_t index) const
{
    FindRoute route;
    if (index >= perMethods.size())
    {
        BMCWEB_LOG_CRITICAL << "Bad index???";
        return route;
    }

    const PerMethod& slot = perMethods[index];
    std::pair<unsigned, RoutingParams> match = slot.trie.find(url);
    if (match.first >= slot.rules.size())
    {
        throw std::runtime_error("Trie internal structure corrupted!");
    }

    // Rule index 0 means the trie found no route for this URL.
    if (match.first == 0U)
    {
        return route;
    }

    route.rule = slot.rules[match.first];
    route.params = std::move(match.second);
    return route;
}

/**
 * Resolves the route for @p req.
 *
 * Every HTTP verb slot is probed for the request URL. Each verb with a
 * matching rule is appended (comma-separated) to the response's allowHeader.
 * The rule for the request's own verb, if any, becomes the response's route.
 * If the request method has no HttpVerb equivalent, an empty response is
 * returned.
 */
Router::FindRouteResponse Router::findRoute(Request& req) const
{
    FindRouteResponse response;

    std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
    if (!verb)
    {
        return response;
    }
    const size_t requestedIndex = static_cast<size_t>(*verb);

    // Every verb index probed below must be a valid perMethods slot.
    static_assert(maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);

    for (size_t verbIndex = 0; verbIndex <= maxVerbIndex; verbIndex++)
    {
        FindRoute candidate =
            findRouteByIndex(req.url().encoded_path(), verbIndex);
        if (candidate.rule == nullptr)
        {
            continue;
        }

        if (!response.allowHeader.empty())
        {
            response.allowHeader += ", ";
        }
        response.allowHeader +=
            httpVerbToString(static_cast<HttpVerb>(verbIndex));

        if (verbIndex == requestedIndex)
        {
            response.route = std::move(candidate);
        }
    }
    return response;
}

/**
 * Dispatches @p req to the rule registered for its URL and HTTP verb.
 *
 * Resolution order:
 *  1. the rule matching the URL and the request's verb (via findRoute);
 *  2. if no verb matches the URL, the rule registered in the not-found slot;
 *     if some other verb matches, the rule in the method-not-allowed slot;
 *  3. otherwise a bare 404 or 405 status.
 * When any verb matches the URL, the comma-separated list of those verbs is
 * added as an Allow header.
 *
 * Requests whose method has no HttpVerb equivalent get 404.
 *
 * With BMCWEB_ENABLE_GRPC, requests that did not come from gRPC, or gRPC
 * requests from an authenticated peer while insecureDisableGrpcRedfishAuthz
 * is set, go straight to the handler. Other gRPC requests must pass
 * BmcWebAuthorizer first, and the authorization time is accounted in
 * grpcstats(). Without gRPC, the handler runs via validatePrivilege(),
 * presumably after the session privilege check in isUserPrivileged
 * (confirm in routing.hpp).
 */
void Router::handle(Request& req,
                    const std::shared_ptr<bmcweb::AsyncResp>& asyncResp)
{
    std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
    if (!verb || static_cast<size_t>(*verb) >= perMethods.size())
    {
        asyncResp->res.result(boost::beast::http::status::not_found);
        return;
    }

    FindRouteResponse foundRoute = findRoute(req);

    if (foundRoute.route.rule == nullptr)
    {
        // Couldn't find a normal route with any verb, try looking for a 404
        // route
        if (foundRoute.allowHeader.empty())
        {
            foundRoute.route =
                findRouteByIndex(req.url().encoded_path(), notFoundIndex);
        }
        else
        {
            // See if we have a method not allowed (405) handler
            foundRoute.route = findRouteByIndex(req.url().encoded_path(),
                                                methodNotAllowedIndex);
        }
    }

    // Fill in the allow header if it's valid
    if (!foundRoute.allowHeader.empty())
    {

        asyncResp->res.addHeader(boost::beast::http::field::allow,
                                 foundRoute.allowHeader);
    }

    // If we couldn't find a real route or a 404 route, return a generic
    // response
    if (foundRoute.route.rule == nullptr)
    {
        if (foundRoute.allowHeader.empty())
        {
            asyncResp->res.result(boost::beast::http::status::not_found);
        }
        else
        {
            asyncResp->res.result(
                boost::beast::http::status::method_not_allowed);
        }
        return;
    }

    BaseRule& rule = *foundRoute.route.rule;
    RoutingParams params = std::move(foundRoute.route.params);

    BMCWEB_LOG_INFO << "Matched rule '" << rule.rule << "' "
                    << static_cast<uint32_t>(*verb) << " / "
                    << rule.getMethods();

#ifdef BMCWEB_ENABLE_GRPC
    // NOTE(review): in the gRPC build, requests that did not arrive over gRPC
    // reach the handler here without validatePrivilege(). Confirm that this
    // is intentional, e.g. that HTTP is authorized elsewhere in this
    // configuration.
    if (!req.fromGrpc ||
        (insecureDisableGrpcRedfishAuthz && req.authzState.peer_authenticated))
    {
        rule.handle(req, asyncResp, params);
        return;
    }

    // Start authorization latency timing
    grpc_monotonic_clock::time_point start_time = grpc_monotonic_clock::now();
    // Handles dynamic fine-grained authorization
    std::string_view url_str = {req.url().encoded_path().data(),
                                req.url().encoded_path().size()};
    grpc::Status status =
        ::milotic::authz::BmcWebAuthorizerSingleton::GetInstance().Authorize(
            url_str, req.method(), req.authzState);

    // Finish authorization latency timing
    grpc_redfish::int64_duration milliseconds =
        std::chrono::duration_cast<grpc_redfish::int64_duration>(
            grpc_monotonic_clock::now() - start_time);
    asyncResp->res.grpcstats().authz_latency_total_ms += milliseconds;
    // grpc server code is measuring the time in bmcweb code as processing
    // time, so deduct the authorization timing from the processing latency
    asyncResp->res.grpcstats().processing_latency_total_ms -= milliseconds;

    if (!status.ok())
    {
        // Denied: reply with a Redfish ResourceAtUriUnauthorized error and
        // count the failure; the handler is not invoked.
        asyncResp->res.addHeader("OData-Version", "4.0");
        LOG(WARNING) << "Authorization failure at " << req.url() << ": "
                     << status.error_message();
        ::redfish::messages::resourceAtUriUnauthorized(
            asyncResp->res, req.url(), status.error_message());
        asyncResp->res.grpcstats().total_authorized_fail_count++;
        return;
    }

    asyncResp->res.grpcstats().total_authorized_count++;
    rule.handle(req, asyncResp, params);
#else
    // The handler is deferred into a callback. rule is captured by reference
    // (presumably Router-owned and outliving the request; verify), while
    // asyncResp and params are copied into the closure.
    validatePrivilege(req, asyncResp, rule,
                      [&rule, asyncResp, params](Request& thisReq) mutable {
        rule.handle(thisReq, asyncResp, params);
    });
#endif
}

void Router::debugPrint()
{
    for (size_t i = 0; i < perMethods.size(); i++)
    {
        BMCWEB_LOG_DEBUG << boost::beast::http::to_string(
            static_cast<boost::beast::http::verb>(i));
        perMethods[i].trie.debugPrint();
    }
}

/**
 * Returns pointers to the route strings of every rule, across all per-method
 * slots, reported by Trie::findRouteIndexes for the prefix @p parent.
 *
 * A rule registered in several slots appears once per slot. The pointers
 * refer to the `rule` member of the registered rule objects.
 */
std::vector<const std::string*> Router::getRoutes(const std::string& parent)
{
    std::vector<const std::string*> routes;
    std::vector<unsigned> ruleIndexes;

    for (const PerMethod& perMethod : perMethods)
    {
        ruleIndexes.clear();
        perMethod.trie.findRouteIndexes(parent, ruleIndexes);
        for (const unsigned ruleIndex : ruleIndexes)
        {
            routes.push_back(&perMethod.rules[ruleIndex]->rule);
        }
    }
    return routes;
}

/**
 * Always throws std::runtime_error reporting that no handler exists for
 * @p url with HTTP method @p boostVerb.
 */
void Router::throwHandlerNotFound(std::string_view url,
                                  boost::beast::http::verb boostVerb)
{
    std::string message("handler doesn't exist for ");
    message.append(url);
    message.append(", method ");
    message.append(boostVerbToString(boostVerb));
    throw std::runtime_error(message);
}

} // namespace crow