#include "utility.hpp"

#include <openssl/crypto.h>

#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <string>
#include <string_view>
#include <variant>

#include "bmcweb_config.h"

namespace crow {

namespace utility {

// Returns the protocol to use for `urlView`: "https" for https URLs, "http"
// for http URLs only when bmcwebInsecureEnableHttpPushStyleEventing is set,
// and an empty string (meaning "not allowed") for everything else.
std::string setProtocolDefaults(boost::urls::url_view urlView)  // NOLINT
{
  const auto scheme = urlView.scheme();
  if (scheme == "https") {
    return "https";
  }
  // Plain http is only accepted when the insecure push-style eventing option
  // is enabled; otherwise it is rejected like any unknown scheme.
  const bool httpAllowed =
      (scheme == "http") && bmcwebInsecureEnableHttpPushStyleEventing;
  return httpAllowed ? "http" : "";
}

// Compares two strings for equality using OpenSSL's CRYPTO_memcmp so the
// byte comparison does not short-circuit on the first mismatch.
// NOTE: differing lengths return early, so timing still reveals whether the
// lengths match; only equal-length inputs are compared in constant time.
bool constantTimeStringCompare(std::string_view a, std::string_view b) {
  const size_t length = a.size();
  if (b.size() != length) {
    return false;
  }
  const int difference = CRYPTO_memcmp(a.data(), b.data(), length);
  return difference == 0;
}

namespace details {
// Matches the path segments of `url`, in order, against `segments` (one
// matcher per path segment) using UrlSegmentMatcherVisitor.
// Returns false for relative paths or when any matcher reports Fail; returns
// true early when a matcher reports Done; otherwise succeeds only if every
// path segment was consumed (a single trailing empty segment, produced by a
// trailing "/", is tolerated).
bool readUrlSegments(boost::urls::url_view url,  // NOLINT
                     std::initializer_list<UrlSegment>&& segments) {
  boost::urls::segments_view pathSegments = url.segments();

  // Relative paths are never matched.
  if (!pathSegments.is_absolute()) {
    return false;
  }

  auto current = pathSegments.begin();
  auto stop = pathSegments.end();

  for (const auto& matcher : segments) {
    if (current == stop) {
      // The path ran out before the matchers did; this only counts as a
      // match when the pending matcher is an "any remaining paths" marker.
      return std::holds_alternative<OrMorePaths>(matcher);
    }
    UrlParseResult outcome =
        std::visit(UrlSegmentMatcherVisitor(*current), matcher);
    if (outcome == UrlParseResult::Fail) {
      return false;
    }
    if (outcome == UrlParseResult::Done) {
      return true;
    }
    ++current;
  }

  // A URL ending in "/" (e.g. /redfish/v1/Chassis/) has an empty final
  // segment; step over it so it is not treated as unmatched input.
  if (current != stop && pathSegments.back().empty()) {
    ++current;
  }
  return current == stop;
}

}  // namespace details

// Builds a new absolute URL whose path copies the path segments of `urlView`,
// except that the segment at zero-based index `replaceLoc` is replaced by
// `newSegment`.
// - If `urlView`'s path is not absolute, the result is just "/".
// - If `replaceLoc` is past the last segment, nothing is replaced.
// - Only path segments are copied; scheme, authority, query and fragment of
//   `urlView` are not carried over.
// NOTE(review): `uint` is a POSIX/glibc typedef, not standard C++; consider
// `unsigned int` in both the header and here.
boost::urls::url replaceUrlSegment(boost::urls::url_view urlView,  // NOLINT
                                   const uint replaceLoc,
                                   std::string_view newSegment) {
  boost::urls::url result("/");
  boost::urls::segments_view sourceSegments = urlView.segments();

  if (!sourceSegments.is_absolute()) {
    return result;
  }

  uint position = 0;
  for (const auto& segment : sourceSegments) {
    if (position == replaceLoc) {
      result.segments().push_back(newSegment);
    } else {
      result.segments().push_back(segment);
    }
    ++position;
  }

  return result;
}

// Returns the port for `url`: the port written in the URL when
// port_number() is non-zero, otherwise 443 for "https", 80 for "http",
// and 0 for any other scheme.
uint16_t setPortDefaults(boost::urls::url_view url)  // NOLINT
{
  const uint16_t explicitPort = url.port_number();
  if (explicitPort != 0) {
    // An explicit port in the URL always wins.
    return explicitPort;
  }

  // No explicit port: fall back to the well-known port for the scheme.
  const auto scheme = url.scheme();
  if (scheme == "https") {
    return 443;
  }
  if (scheme == "http") {
    return 80;
  }
  return 0;
}

// Parses `destUrl` as an absolute URI and splits it into its parts.
//
// Params (all outputs except destUrl):
//   urlProto - protocol chosen by setProtocolDefaults ("https", or "http"
//              when insecure push-style eventing is enabled)
//   host     - percent-encoded host
//   port     - explicit port, or the scheme's default from setPortDefaults
//   path     - percent-encoded path ("/" when empty), followed by
//              "?query" and then "#fragment" when present
// Returns false if the URI does not parse or its scheme is not allowed; in
// that case the outputs may be partially written.
bool validateAndSplitUrl(std::string_view destUrl, std::string& urlProto,
                         std::string& host, uint16_t& port, std::string& path) {
  auto url = boost::urls::parse_uri(destUrl);
  if (!url) {
    return false;
  }
  urlProto = setProtocolDefaults(url.value());
  if (urlProto.empty()) {
    return false;
  }

  port = setPortDefaults(url.value());

  host = url->encoded_host();

  path = url->encoded_path();
  if (path.empty()) {
    path = "/";
  }

  // RFC 3986 section 3: the query precedes the fragment
  // ("path?query#fragment"). Appending the fragment first would produce
  // "path#fragment?query", where the query is swallowed by the fragment.
  if (url->has_query()) {
    path += '?';
    path += url->encoded_query();
  }

  if (url->has_fragment()) {
    path += '#';
    path += url->encoded_fragment();
  }

  return true;
}

// Encodes `data` as base64 using the standard alphabet (A-Z, a-z, 0-9, '+',
// '/') and '=' padding, so the result length is always a multiple of 4.
std::string base64encode(std::string_view data) {
  constexpr std::string_view alphabet =
      "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

  // Read bytes as unsigned so that values >= 0x80 are not sign-extended.
  auto byteAt = [&data](size_t pos) -> uint32_t {
    return static_cast<unsigned char>(data[pos]);
  };
  auto sextet = [&alphabet](uint32_t group, unsigned shift) {
    return alphabet[(group >> shift) & 0x3F];
  };

  std::string encoded;
  encoded.reserve((data.size() + 2) / 3 * 4);

  size_t pos = 0;
  // Each complete 3-byte group becomes one 24-bit value -> 4 characters.
  for (; pos + 3 <= data.size(); pos += 3) {
    const uint32_t group =
        (byteAt(pos) << 16) | (byteAt(pos + 1) << 8) | byteAt(pos + 2);
    encoded += sextet(group, 18);
    encoded += sextet(group, 12);
    encoded += sextet(group, 6);
    encoded += sextet(group, 0);
  }

  // A trailing 1- or 2-byte group is zero-extended and padded with '='.
  const size_t remaining = data.size() - pos;
  if (remaining == 1) {
    const uint32_t group = byteAt(pos) << 16;
    encoded += sextet(group, 18);
    encoded += sextet(group, 12);
    encoded += "==";
  } else if (remaining == 2) {
    const uint32_t group = (byteAt(pos) << 16) | (byteAt(pos + 1) << 8);
    encoded += sextet(group, 18);
    encoded += sextet(group, 12);
    encoded += sextet(group, 6);
    encoded += '=';
  }

  return encoded;
}

// Decodes base64 text (standard alphabet: A-Z, a-z, 0-9, '+', '/') from
// `input` into `output`.
//
// Returns true on success. Returns false if a character outside the
// alphabet is found before any '=' padding, if a group has only one
// character, or if the bits discarded by '=' padding are non-zero.
// `output` is cleared first and may hold partial data when false is returned.
//
// NOTE: decoding is lenient. Unpadded input ("Zg") is accepted. Decoding
// stops at the first '=' at the third or fourth position of a group, and
// any characters after it are not examined.
bool base64Decode(std::string_view input, std::string& output) {
  static const char nop = static_cast<char>(-1);
  // Maps every byte value to its 6-bit value in the base64 alphabet
  // ('A'-'Z' -> 0-25, 'a'-'z' -> 26-51, '0'-'9' -> 52-61, '+' -> 62,
  // '/' -> 63). All other bytes, including '=', map to nop.
  static const std::array<char, 256> decodingData = {
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, 62,  nop,
      nop, nop, 63,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  nop, nop,
      nop, nop, nop, nop, nop, 0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
      10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,
      25,  nop, nop, nop, nop, nop, nop, 26,  27,  28,  29,  30,  31,  32,  33,
      34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
      49,  50,  51,  nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop, nop,
      nop};

  size_t inputLength = input.size();

  // Every 4 input characters decode to at most 3 bytes, and a trailing
  // partial group decodes to fewer, so this is an upper bound on the output.
  // (The previous formula, (n + 2) / 3 * 4, is the *encoding* size and
  // over-reserved by roughly 1.8x.)
  output.clear();
  output.reserve((inputLength + 3) / 4 * 3);

  auto getCodeValue = [](char c) {
    auto code = static_cast<unsigned char>(c);
    // Ensure we cannot index outside the bounds of the decoding array
    static_assert(std::numeric_limits<decltype(code)>::max() <
                  decodingData.size());
    return decodingData[code];
  };

  // For each group of up to 4 input characters, look up 4 6-bit values and
  // repack them into up to 3 8-bit output bytes.
  for (size_t i = 0; i < inputLength; i++) {
    char base64code0 = 0;
    char base64code1 = 0;
    char base64code2 = 0;  // initialized to 0 to suppress warnings

    base64code0 = getCodeValue(input[i]);
    if (base64code0 == nop) {  // non base64 character
      return false;
    }
    if (!(++i < inputLength)) {  // we need at least two input characters
                                 // to produce the first output byte
      return false;
    }
    base64code1 = getCodeValue(input[i]);
    if (base64code1 == nop) {  // non base64 character
      return false;
    }
    output +=
        static_cast<char>((base64code0 << 2) | ((base64code1 >> 4) & 0x3));

    if (++i < inputLength) {
      char c = input[i];
      if (c == '=') {  // padding, end of input
        // The 4 low bits of code1 are discarded and must be zero.
        return (base64code1 & 0x0f) == 0;
      }
      base64code2 = getCodeValue(input[i]);
      if (base64code2 == nop) {  // non base64 character
        return false;
      }
      output += static_cast<char>(((base64code1 << 4) & 0xf0) |
                                  ((base64code2 >> 2) & 0x0f));
    }

    if (++i < inputLength) {
      char c = input[i];
      if (c == '=') {  // padding, end of input
        // The 2 low bits of code2 are discarded and must be zero.
        return (base64code2 & 0x03) == 0;
      }
      char base64code3 = getCodeValue(input[i]);
      if (base64code3 == nop) {  // non base64 character
        return false;
      }
      output += static_cast<char>((((base64code2 << 6) & 0xc0) | base64code3));
    }
  }

  return true;
}

}  // namespace utility
}  // namespace crow

namespace crow {
namespace black_magic {
bool isParameterTagCompatible(uint64_t a, uint64_t b) {
  while (true) {
    if (a == 0 && b == 0) {
      // Both tags were equivalent, parameters are compatible
      return true;
    }
    if (a == 0 || b == 0) {
      // one of the tags had more parameters than the other
      return false;
    }
    TypeCode sa = static_cast<TypeCode>(a % toUnderlying(TypeCode::Max));
    TypeCode sb = static_cast<TypeCode>(b % toUnderlying(TypeCode::Max));

    if (sa == TypeCode::Path) {
      sa = TypeCode::String;
    }
    if (sb == TypeCode::Path) {
      sb = TypeCode::String;
    }
    if (sa != sb) {
      return false;
    }
    a /= toUnderlying(TypeCode::Max);
    b /= toUnderlying(TypeCode::Max);
  }
}
}  // namespace black_magic
}  // namespace crow
