#include "routing.hpp"

#include <array>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/log/log.h"  // NOLINT(misc-include-cleaner)
#include "boost/beast/ssl/ssl_stream.hpp"  // NOLINT
#include "bmcweb_config.h"  // NOLINT(misc-include-cleaner)
#include "common.hpp"
#include "http_request.hpp"
#include "http_response.hpp"
#include "logging.hpp"
#include "utility.hpp"
#include "verb.hpp"
#include "async_resp.hpp"
#include "dbus_utility.hpp"
#include "error_messages.hpp"
#include "privileges.hpp"
#include "dbus_utils.hpp"
#include "grpcpp/support/status.h"  // NOLINT(misc-include-cleaner)
#include "bmcweb_authorizer_singleton.h"  // NOLINT(misc-include-cleaner)
#include "sdbusplus/unpack_properties.hpp"

namespace crow {

// Decides whether the session user may invoke `rule`, based on the user
// properties read from D-Bus (`userInfoMap`).
//
// Properties consumed: "UserPrivilege" (role string, may be absent),
// "RemoteUser" (bool, required) and "UserPasswordExpired" (bool, required
// for local users only). A user whose password has expired is limited to
// the ConfigureSelf privilege.
//
// Returns true and stores the role in `req.userRole` when the rule's
// privileges are satisfied. Otherwise sets the response status on
// `asyncResp` and returns false: 500 for missing/mistyped D-Bus data, 403
// for insufficient privileges (plus a PasswordChangeRequired message when
// the session is limited to ConfigureSelf). Once the D-Bus data is valid,
// `req.session->isConfigureSelfOnly` is always overwritten.
bool Router::isUserPrivileged(
    Request& req, const std::shared_ptr<bmcweb::AsyncResp>& asyncResp,
    BaseRule& rule, const dbus::utility::DBusPropertiesMap& userInfoMap) {
  const std::string* rolePtr = nullptr;
  const bool* remoteUserPtr = nullptr;
  const bool* passwordExpiredPtr = nullptr;

  if (!sdbusplus::unpackPropertiesNoThrow(
          redfish::dbus_utils::UnpackErrorPrinter(), userInfoMap,
          "UserPrivilege", rolePtr, "RemoteUser", remoteUserPtr,
          "UserPasswordExpired", passwordExpiredPtr)) {
    asyncResp->res.result(boost::beast::http::status::internal_server_error);
    return false;
  }

  // A missing role is tolerated; it is passed to getUserPrivileges() as an
  // empty role name.
  std::string role;
  if (rolePtr != nullptr) {
    role = *rolePtr;
    BMCWEB_LOG_DEBUG << "userName = " << req.session->username
                     << " user_role = " << role;
  }

  if (remoteUserPtr == nullptr) {
    BMCWEB_LOG_ERROR << "RemoteUser property missing or wrong type";
    asyncResp->res.result(boost::beast::http::status::internal_server_error);
    return false;
  }

  // Only local users must report password expiry; for remote users a
  // missing value is treated as "not expired".
  if (passwordExpiredPtr == nullptr && !*remoteUserPtr) {
    BMCWEB_LOG_ERROR << "UserPasswordExpired property is expected for"
                        " local user but is missing or wrong type";
    asyncResp->res.result(boost::beast::http::status::internal_server_error);
    return false;
  }
  const bool passwordExpired =
      (passwordExpiredPtr != nullptr) && *passwordExpiredPtr;

  redfish::Privileges privileges = redfish::getUserPrivileges(role);

  // The D-Bus result is authoritative: it overrides whatever
  // pamAuthenticateUser or an earlier request on this session decided.
  req.session->isConfigureSelfOnly = passwordExpired;
  if (passwordExpired) {
    // Keep only the ConfigureSelf privilege.
    privileges =
        privileges.intersection(redfish::Privileges{"ConfigureSelf"});
    BMCWEB_LOG_DEBUG << "Operation limited to ConfigureSelf";
  }

  if (rule.checkPrivileges(privileges)) {
    req.userRole = role;
    return true;
  }

  asyncResp->res.result(boost::beast::http::status::forbidden);
  if (passwordExpired) {
    redfish::messages::passwordChangeRequired(
        asyncResp->res,
        crow::utility::urlFromPieces("redfish", "v1", "AccountService",
                                     "Accounts", req.session->username));
  }
  return false;
}

// Compacts the trie below `node` in place.
//
// Parameter children are optimized first. Then, if every literal child of
// `node` satisfies Node::isSimpleNode() (assumed to mean "no rule and no
// parameter children" — confirm in Node), that child level is collapsed.
// Each grandchild is re-attached directly to `node` under the concatenation
// of the child's and grandchild's fragments, and `node` is optimized again.
// Otherwise each literal child is optimized recursively. Only `children`
// maps are rewritten; no entries are added to `nodes`.
void Trie::optimizeNode(Node* node) {
  for (size_t paramChild : node->paramChildrens) {
    // A slot value of 0 means "no parameter child of this type".
    if (paramChild != 0U) {
      optimizeNode(&nodes[paramChild]);
    }
  }
  if (node->children.empty()) {
    return;
  }

  bool allChildrenSimple = true;
  for (const Node::ChildMap::value_type& entry : node->children) {
    if (!nodes[entry.second].isSimpleNode()) {
      allChildrenSimple = false;
      break;
    }
  }

  if (!allChildrenSimple) {
    for (const Node::ChildMap::value_type& entry : node->children) {
      optimizeNode(&nodes[entry.second]);
    }
    return;
  }

  Node::ChildMap collapsed;
  for (const Node::ChildMap::value_type& entry : node->children) {
    const Node& child = nodes[entry.second];
    for (const Node::ChildMap::value_type& grandchild : child.children) {
      collapsed[entry.first + grandchild.first] = grandchild.second;
    }
  }
  node->children = std::move(collapsed);
  optimizeNode(node);
}

// Appends to `routeIndexes` the rule indexes found under the part of the
// trie reached by matching `reqUrl` from offset `pos`. The walk starts at
// the trie head when `node` is nullptr.
//
// While `reqUrl` still has unmatched characters, only literal children
// whose whole fragment matches at `pos` are followed. Once it is exhausted,
// every literal descendant carrying a rule is reported, except nodes
// reached through a bare "/" fragment (presumably the trailing-slash
// aliases registered by Router::internalAddRuleObject). Parameter children
// are never visited.
// NOTE(review): a fragment that extends past the end of `reqUrl` is not
// followed even if `reqUrl` is its prefix (the whole fragment must match).
void Trie::findRouteIndexes(const std::string& reqUrl,
                            std::vector<unsigned>& routeIndexes,
                            const Node* node, unsigned pos) const {
  const Node* current = (node != nullptr) ? node : head();
  const bool urlConsumed = pos >= reqUrl.size();

  for (const Node::ChildMap::value_type& entry : current->children) {
    const std::string& fragment = entry.first;
    const Node* child = &nodes[entry.second];
    const unsigned nextPos = static_cast<unsigned>(pos + fragment.size());

    if (urlConsumed) {
      if (child->ruleIndex != 0 && fragment != "/") {
        routeIndexes.push_back(child->ruleIndex);
      }
      findRouteIndexes(reqUrl, routeIndexes, child, nextPos);
    } else if (reqUrl.compare(pos, fragment.size(), fragment) == 0) {
      findRouteIndexes(reqUrl, routeIndexes, child, nextPos);
    }
  }
}

std::pair<unsigned int, RoutingParams> Trie::find(std::string_view reqUrl,
                                                  const Node* node, size_t pos,
                                                  RoutingParams* params) const {
  RoutingParams empty;
  if (params == nullptr) {
    params = &empty;
  }

  unsigned found{};
  RoutingParams match_params;

  if (node == nullptr) {
    node = head();
  }
  if (pos == reqUrl.size()) {
    return {node->ruleIndex, *params};
  }

  auto update_found = [&found,
                       &match_params](std::pair<unsigned, RoutingParams>& ret) {
    // Current Logic is to choose the rule with the lowest index in the allRules
    // vector
    bool replaceFound = found == 0U || found > ret.first;

    // With plugins enabled, we will return the first matched
    // found instead. This results in matching fixed URI segments whenever
    // possible relative to wildcards.
#if defined(PLATFORM_PLUGINS_ENABLED) ||      \
    defined(GOOGLE_COMMON_PLUGINS_ENABLED) || \
    defined(GOOGLE_PLUGINS_INTERNAL_ENABLED)
    replaceFound = found == 0U;
#endif

    if (ret.first != 0U && replaceFound) {
      found = ret.first;
      match_params = std::move(ret.second);
    }
  };

  for (const Node::ChildMap::value_type& kv : node->children) {
    const std::string& fragment = kv.first;
    const Node* child = &nodes[kv.second];

    if (reqUrl.compare(pos, fragment.size(), fragment) == 0) {
      std::pair<unsigned, RoutingParams> ret =
          find(reqUrl, child, pos + fragment.size(), params);
      update_found(ret);
    }
  }

  if (node->paramChildrens[static_cast<size_t>(ParamType::INT)] != 0U) {
    char c = reqUrl[pos];
    if ((c >= '0' && c <= '9') || c == '+' || c == '-') {
      char* eptr = nullptr;
      errno = 0;
      int64_t value = std::strtoll(reqUrl.data() + pos, &eptr, 10);  // NOLINT
      if (errno != ERANGE && eptr != reqUrl.data() + pos) {
        params->intParams.push_back(value);
        std::pair<unsigned, RoutingParams> ret = find(
            reqUrl,
            &nodes[node->paramChildrens[static_cast<size_t>(ParamType::INT)]],
            static_cast<size_t>(eptr - reqUrl.data()), params);
        update_found(ret);
        params->intParams.pop_back();
      }
    }
  }

  if (node->paramChildrens[static_cast<size_t>(ParamType::UINT)] != 0U) {
    char c = reqUrl[pos];
    if ((c >= '0' && c <= '9') || c == '+') {
      char* eptr = nullptr;
      errno = 0;
      uint64_t value = std::strtoull(reqUrl.data() + pos, &eptr, 10);  // NOLINT
      if (errno != ERANGE && eptr != reqUrl.data() + pos) {
        params->uintParams.push_back(value);
        std::pair<unsigned, RoutingParams> ret = find(
            reqUrl,
            &nodes[node->paramChildrens[static_cast<size_t>(ParamType::UINT)]],
            static_cast<size_t>(eptr - reqUrl.data()), params);
        update_found(ret);
        params->uintParams.pop_back();
      }
    }
  }

  if (node->paramChildrens[static_cast<size_t>(ParamType::DOUBLE)] != 0U) {
    char c = reqUrl[pos];
    if ((c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.') {
      char* eptr = nullptr;
      errno = 0;
      double value = std::strtod(reqUrl.data() + pos, &eptr);
      if (errno != ERANGE && eptr != reqUrl.data() + pos) {
        params->doubleParams.push_back(value);
        std::pair<unsigned, RoutingParams> ret =
            find(reqUrl,
                 &nodes[node->paramChildrens[static_cast<size_t>(
                     ParamType::DOUBLE)]],
                 static_cast<size_t>(eptr - reqUrl.data()), params);
        update_found(ret);
        params->doubleParams.pop_back();
      }
    }
  }

  if (node->paramChildrens[static_cast<size_t>(ParamType::STRING)] != 0U) {
    size_t epos = pos;
    for (; epos < reqUrl.size(); epos++) {
      if (reqUrl[epos] == '/') {
        break;
      }
    }

    if (epos != pos) {
      params->stringParams.emplace_back(reqUrl.substr(pos, epos - pos));
      std::pair<unsigned, RoutingParams> ret = find(
          reqUrl,
          &nodes[node->paramChildrens[static_cast<size_t>(ParamType::STRING)]],
          epos, params);
      update_found(ret);
      params->stringParams.pop_back();
    }
  }

  if (node->paramChildrens[static_cast<size_t>(ParamType::PATH)] != 0U) {
    size_t epos = reqUrl.size();

    if (epos != pos) {
      params->stringParams.emplace_back(reqUrl.substr(pos, epos - pos));
      std::pair<unsigned, RoutingParams> ret = find(
          reqUrl,
          &nodes[node->paramChildrens[static_cast<size_t>(ParamType::PATH)]],
          epos, params);
      update_found(ret);
      params->stringParams.pop_back();
    }
  }

  return {found, match_params};
}

// Inserts the route pattern `url` into the trie and associates its terminal
// node with `ruleIndex`.
//
// Literal characters become single-character child nodes. The tags <int>,
// <uint>, <float>/<double>, <str>/<string> and <path> become parameter
// children of the corresponding ParamType.
//
// Throws std::runtime_error if `url` contains an unrecognised '<...>' tag,
// or if a handler is already registered for the same pattern.
void Trie::add(const std::string& url, unsigned int ruleIndex) {
  size_t idx = 0;

  for (unsigned i = 0; i < url.size(); i++) {
    char c = url[i];
    if (c == '<') {
      static constexpr std::array<std::pair<ParamType, std::string_view>, 7>
          // NOLINTNEXTLINE
          paramTraits = {{
              {ParamType::INT, "<int>"},
              {ParamType::UINT, "<uint>"},
              {ParamType::DOUBLE, "<float>"},
              {ParamType::DOUBLE, "<double>"},
              {ParamType::STRING, "<str>"},
              {ParamType::STRING, "<string>"},
              {ParamType::PATH, "<path>"},
          }};

      bool matched = false;
      for (const std::pair<ParamType, std::string_view>& x : paramTraits) {
        if (url.compare(i, x.second.size(), x.second) == 0) {
          size_t index = static_cast<size_t>(x.first);
          if (nodes[idx].paramChildrens[index] == 0U) {
            // newNode() presumably appends to `nodes` (possibly
            // reallocating), so take the index before touching nodes[idx].
            unsigned new_node_idx = newNode();
            nodes[idx].paramChildrens[index] = new_node_idx;
          }
          idx = nodes[idx].paramChildrens[index];
          i += static_cast<unsigned>(x.second.size());
          matched = true;
          break;
        }
      }

      // Without a match `i` would be left unchanged, and the i-- below plus
      // the loop's i++ would revisit the same '<' forever. Reject unknown
      // or malformed tags instead.
      if (!matched) {
        throw std::runtime_error("unknown parameter tag in route " + url);
      }

      // `i` now points just past the tag; step back so the loop's i++ lands
      // on that character.
      i--;
    } else {
      std::string piece(&c, 1);
      if (nodes[idx].children.count(piece) == 0U) {
        unsigned new_node_idx = newNode();
        nodes[idx].children.emplace(piece, new_node_idx);
      }
      idx = nodes[idx].children[piece];
    }
  }
  if (nodes[idx].ruleIndex != 0U) {
    throw std::runtime_error("handler already exists for " + url);
  }
  nodes[idx].ruleIndex = ruleIndex;
}

// Logs the subtree rooted at `n` at debug level, indenting two spaces per
// `level`. Parameter children are logged first, as an indent line followed
// by their tag name. Literal children are then logged as indent plus
// fragment. Each child's subtree is printed recursively at `level + 1`.
void Trie::debugNodePrint(Node* n, size_t level) {
  const std::string indent(2U * level, ' ');

  for (size_t slot = 0; slot < static_cast<size_t>(ParamType::MAX); ++slot) {
    const size_t childIndex = n->paramChildrens[slot];
    if (childIndex == 0U) {
      continue;
    }
    BMCWEB_LOG_DEBUG << indent;

    const char* tagName = nullptr;
    switch (static_cast<ParamType>(slot)) {
      case ParamType::INT:
        tagName = "<int>";
        break;
      case ParamType::UINT:
        tagName = "<uint>";
        break;
      case ParamType::DOUBLE:
        tagName = "<float>";
        break;
      case ParamType::STRING:
        tagName = "<str>";
        break;
      case ParamType::PATH:
        tagName = "<path>";
        break;
      case ParamType::MAX:
        tagName = "<ERROR>";
        break;
    }
    if (tagName != nullptr) {
      BMCWEB_LOG_DEBUG << tagName;
    }

    debugNodePrint(&nodes[childIndex], level + 1);
  }

  for (const Node::ChildMap::value_type& entry : n->children) {
    BMCWEB_LOG_DEBUG << indent << entry.first;
    debugNodePrint(&nodes[entry.second], level + 1);
  }
}

void Router::internalAddRuleObject(const std::string& rule,
                                   BaseRule* ruleObject) {
  if (ruleObject == nullptr) {
    return;
  }
  for (size_t method = 0, method_bit = 1; method <= methodNotAllowedIndex;
       method++, method_bit <<= 1) {
    if ((ruleObject->methodsBitfield & method_bit) > 0U) {
      perMethods[method].rules.push_back(ruleObject);
      perMethods[method].trie.add(
          rule, static_cast<unsigned>(perMethods[method].rules.size() - 1U));
      // directory case:
      //   request to `/about' url matches `/about/' rule
      if (rule.size() > 2 && rule.back() == '/') {
        perMethods[method].trie.add(
            rule.substr(0, rule.size() - 1),
            static_cast<unsigned>(perMethods[method].rules.size() - 1));
      }
    }
  }
}

void Router::validate() {
  for (std::unique_ptr<BaseRule>& rule : allRules) {
    if (rule) {
      std::unique_ptr<BaseRule> upgraded = rule->upgrade();
      if (upgraded) {
        rule = std::move(upgraded);
      }
      rule->validate();
      internalAddRuleObject(rule->rule, rule.get());
    }
  }
  for (PerMethod& per_method : perMethods) {
    per_method.trie.validate();
  }
}

// Looks up `url` in the per-method table `perMethods[index]`.
//
// Returns an empty FindRoute (null rule) when `index` is out of range or
// nothing matches. Otherwise returns the matched rule and its captured
// parameters. Throws std::runtime_error if the trie reports a rule index
// outside the table's rule list.
Router::FindRoute Router::findRouteByIndex(std::string_view url,
                                           size_t index) const {
  FindRoute route;
  if (index >= perMethods.size()) {
    BMCWEB_LOG_CRITICAL << "Bad index???";
    return route;
  }

  const PerMethod& table = perMethods[index];
  auto [ruleIdx, matchedParams] = table.trie.find(url);
  if (ruleIdx >= table.rules.size()) {
    throw std::runtime_error("Trie internal structure corrupted!");
  }
  // Rule index 0 is Trie::find's "no match" value.
  if (ruleIdx == 0U) {
    return route;
  }
  route.rule = table.rules[ruleIdx];
  route.params = std::move(matchedParams);
  return route;
}

// Looks up the request URL in every verb table (indexes 0..maxVerbIndex,
// interpreted as HttpVerb values).
//
// The result's allowHeader lists, comma-separated, every verb with a
// matching rule. Its route is the match for the request's own verb, if
// any. Requests whose method has no HttpVerb mapping get an empty result.
Router::FindRouteResponse Router::findRoute(Request& req) const {
  FindRouteResponse response;

  const std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
  if (!verb.has_value()) {
    return response;
  }
  const size_t requestedIndex = static_cast<size_t>(*verb);

  // Every index visited below must be a valid perMethods slot.
  static_assert(maxVerbIndex < std::tuple_size_v<decltype(perMethods)>);
  for (size_t idx = 0; idx <= maxVerbIndex; ++idx) {
    FindRoute candidate = findRouteByIndex(req.url().encoded_path(), idx);
    if (candidate.rule == nullptr) {
      continue;
    }

    if (!response.allowHeader.empty()) {
      response.allowHeader += ", ";
    }
    response.allowHeader += httpVerbToString(static_cast<HttpVerb>(idx));

    if (idx == requestedIndex) {
      response.route = std::move(candidate);
    }
  }
  return response;
}

// Routes `req` to a handler, which then fills in `asyncResp`.
//
// Resolution, as implemented below:
//   * Methods that httpVerbFromBoost() cannot map, or whose verb index has
//     no perMethods slot, get a bare 404.
//   * A rule registered for the request's verb is used when one matches.
//   * Otherwise, if no catch-all rule is installed:
//       - the notFoundIndex handler table is consulted when no verb matched
//         the URL;
//       - the methodNotAllowedIndex table is consulted when some other verb
//         did.
//     If a catch-all rule is installed, it is used instead.
//   * If nothing was found, a bare 404 is set when no verb matched the URL,
//     and a bare 405 otherwise.
// An Allow header listing the matching verbs is added whenever at least one
// verb matched the URL.
//
// Authorization: with BMCWEB_ENABLE_GRPC defined, requests that arrived over
// gRPC are checked by BmcWebAuthorizerSingleton (unless the insecure bypass
// applies). Without it, validatePrivilege() gates the handler.
void Router::handle(Request& req,
                    const std::shared_ptr<bmcweb::AsyncResp>& asyncResp) {
  std::optional<HttpVerb> verb = httpVerbFromBoost(req.method());
  if (!verb || static_cast<size_t>(*verb) >= perMethods.size()) {
    asyncResp->res.result(boost::beast::http::status::not_found);
    return;
  }

  FindRouteResponse found_route = findRoute(req);

  if (found_route.route.rule == nullptr && catch_all_rule_ == nullptr) {
    // No route for this verb and no catch-all rule. If no verb matched the
    // URL at all (empty Allow list), try the 404 handler table.
    if (found_route.allowHeader.empty()) {
      found_route.route =
          findRouteByIndex(req.url().encoded_path(), notFoundIndex);
    } else {
      // See if we have a method not allowed (405) handler
      found_route.route =
          findRouteByIndex(req.url().encoded_path(), methodNotAllowedIndex);
    }
  } else if (found_route.route.rule == nullptr) {
    // We found a catch all rule, so we should use that
    found_route.route.rule = catch_all_rule_.get();
  }

  // Fill in the allow header if it's valid
  if (!found_route.allowHeader.empty()) {
    asyncResp->res.addHeader(boost::beast::http::field::allow,
                             found_route.allowHeader);
  }

  // If we couldn't find a real route or a 404 route, return a generic
  // response
  if (found_route.route.rule == nullptr) {
    if (found_route.allowHeader.empty()) {
      asyncResp->res.result(boost::beast::http::status::not_found);
    } else {
      asyncResp->res.result(boost::beast::http::status::method_not_allowed);
    }
    return;
  }

  BaseRule& rule = *found_route.route.rule;
  RoutingParams params = std::move(found_route.route.params);

  BMCWEB_LOG_INFO << "Matched rule '" << rule.rule << "' "
                  << static_cast<uint32_t>(*verb) << " / " << rule.getMethods();

#ifdef BMCWEB_ENABLE_GRPC
  // NOTE(review): in gRPC builds, requests that did not arrive over gRPC
  // reach rule.handle() here without any privilege check in this function.
  // Confirm they are authorized elsewhere.
  if (!req.fromGrpc ||
      (insecureDisableGrpcRedfishAuthz && req.peer_authenticated)) {
    rule.handle(req, asyncResp, params);
    return;
  }

  // Handles dynamic fine-grained authorization
  std::string_view url_str = {req.url().encoded_path().data(),
                              req.url().encoded_path().size()};

  ::milotic::authz::BmcWebAuthorizerSingleton::RequestState authz_state;
  authz_state.with_trust_bundle = req.with_trust_bundle;
  authz_state.peer_authenticated = req.peer_authenticated;
  authz_state.peer_privileges = req.peer_privileges;

  grpc::Status status =
      ::milotic::authz::BmcWebAuthorizerSingleton::GetInstance().Authorize(
          url_str, req.method(), authz_state);

  // On denial, reply with a Redfish ResourceAtUriUnauthorized message
  // instead of invoking the handler.
  if (!status.ok()) {
    asyncResp->res.addHeader("OData-Version", "4.0");
    LOG(WARNING) << "Authorization failure at " << req.url() << ": "
                 << status.error_message();
    ::redfish::messages::resourceAtUriUnauthorized(asyncResp->res, req.url(),
                                                   status.error_message());
    return;
  }

  rule.handle(req, asyncResp, params);
#else
  // The handler runs only if validatePrivilege invokes this callback.
  // `params` is captured by value so it stays valid until then.
  validatePrivilege(req, asyncResp, rule,
                    [&rule, asyncResp, params](Request& thisReq) mutable {
                      rule.handle(thisReq, asyncResp, params);
                    });
#endif
}

void Router::debugPrint() {
  for (size_t i = 0; i < perMethods.size(); i++) {
    BMCWEB_LOG_DEBUG << boost::beast::http::to_string(
        static_cast<boost::beast::http::verb>(i));
    perMethods[i].trie.debugPrint();
  }
}

// Returns pointers to the pattern strings (BaseRule::rule) of the rules that
// Trie::findRouteIndexes reports for `parent`, collected across every
// per-method table. No de-duplication is done, so a rule registered for
// several methods appears once per table.
std::vector<const std::string*> Router::getRoutes(const std::string& parent) {
  std::vector<const std::string*> routes;
  for (const PerMethod& table : perMethods) {
    std::vector<unsigned> ruleIndexes;
    table.trie.findRouteIndexes(parent, ruleIndexes);
    for (const unsigned ruleIndex : ruleIndexes) {
      routes.push_back(&table.rules[ruleIndex]->rule);
    }
  }
  return routes;
}

// Throws std::runtime_error stating that no handler is registered for `url`
// with the HTTP method `boostVerb`.
void Router::throwHandlerNotFound(std::string_view url,
                                  boost::beast::http::verb boostVerb) {
  std::string message("handler doesn't exist for ");
  message.append(url);
  message.append(", method ");
  message += boostVerbToString(boostVerb);
  throw std::runtime_error(message);
}

}  // namespace crow
