// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ec_util.hpp"

#include <span>
#include <stdplus/print.hpp>
#include <stdplus/raw.hpp>
#include <vector>
#include <xyz/openbmc_project/Control/Hoth/error.hpp>

#include "google3/host_commands.h"
#include "host_command.hpp"
#include "message_util.hpp"
#include "payload_update.hpp"

namespace google {
namespace hoth {
namespace internal {
using sdbusplus::error::xyz::openbmc_project::control::hoth::ResponseFailure;

namespace {
// Magic tag a valid persistent panic record carries at the end of its
// panic_record.
constexpr std::string_view kPanicDataMagic = "Pnc!";
// Length of the leading region that ends with kPanicDataMagic; the magic is
// checked in its final kPanicDataMagic.size() bytes.
constexpr uint32_t kPanicDataSize = 144;

// Returns true when bytes [kPanicDataSize - 4, kPanicDataSize) of
// `response_body` spell kPanicDataMagic. A body shorter than kPanicDataSize
// never matches.
bool matchPersistentPanicMagic(std::span<const uint8_t> response_body) {
  if (response_body.size() >= kPanicDataSize) {
    const size_t magic_offset = kPanicDataSize - kPanicDataMagic.size();
    for (size_t k = 0; k < kPanicDataMagic.size(); ++k) {
      if (response_body[magic_offset + k] != kPanicDataMagic[k]) {
        return false;
      }
    }
    return true;
  }
  return false;
}

}  // namespace

// Exception carrying the non-success result code from an EC response header.
// Thrown by EcUtilImpl::getResponseBody() when the header's result is not
// EC_RES_SUCCESS.
class EcException {
 public:
  explicit EcException(uint16_t result) : result_code(result) {}

  // Returns the EC result code reported by Hoth.
  uint16_t get_result() const { return result_code; }

 private:
  uint16_t result_code;
};

// Strips and validates the RspHeader at the front of a raw host command
// response.
//
// Returns a view of the bytes that follow the header. The view aliases
// `response`, so it is only valid while `response` is alive and unmodified.
// Throws EcException carrying the header's result code when that code is not
// EC_RES_SUCCESS. Errors from stdplus::raw::extractRef (presumably raised when
// the response is shorter than RspHeader) propagate unchanged.
[[nodiscard]] std::span<const uint8_t> EcUtilImpl::getResponseBody(
    std::vector<uint8_t>& response) {
  std::span<const uint8_t> remaining(response.data(), response.size());
  const RspHeader& header = stdplus::raw::extractRef<RspHeader>(remaining);
  if (header.result == EC_RES_SUCCESS) {
    return remaining;
  }
  throw EcException(header.result);
}

// Retrieves Hoth's statistics block (EC_PRV_CMD_HOTH_GET_STATISTICS).
//
// If Hoth answers with a non-success result code, the code is logged to stderr
// and ResponseFailure is thrown. Other errors (from sendCommand or copyFrom)
// propagate unchanged.
ec_response_statistics EcUtilImpl::getHothStatistics() const {
  std::vector<uint8_t> raw_response = hostCmd->sendCommand(
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_STATISTICS, kVersionZero,
      /*request=*/nullptr, /*requestSize=*/0);
  try {
    const std::span<const uint8_t> body = getResponseBody(raw_response);
    return stdplus::raw::copyFrom<ec_response_statistics>(body);
  } catch (const EcException& failure) {
    stdplus::print(stderr, "{} received a bad response from Hoth {:#x}\n",
                   __func__, static_cast<uint8_t>(failure.get_result()));
    throw ResponseFailure();
  }
}

// Queries Hoth for its chip information (EC_PRV_CMD_HOTH_CHIP_INFO).
//
// Returns the ec_response_chip_info copied out of the response body.
// If Hoth reports a non-success result code, the code is logged to stderr and
// ResponseFailure is thrown, matching the other getters in this file. Errors
// raised by hostCmd->sendCommand() or stdplus::raw::copyFrom() propagate
// unchanged.
ec_response_chip_info EcUtilImpl::getHothChipInfo() const {
  std::vector<uint8_t> response = hostCmd->sendCommand(
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_CHIP_INFO, kVersionZero,
      /*request=*/nullptr, /*requestSize=*/0);
  std::span<const uint8_t> response_body;
  try {
    response_body = getResponseBody(response);
  } catch (const EcException& ec_exception) {
    // EcException is private to this file and does not derive from
    // std::exception; translate it into the sdbusplus ResponseFailure error
    // instead of letting it escape to callers that cannot name it.
    stdplus::print(stderr, "{} received a bad response from Hoth {:#x}\n",
                   __func__, static_cast<uint8_t>(ec_exception.get_result()));
    throw ResponseFailure();
  }
  return stdplus::raw::copyFrom<ec_response_chip_info>(response_body);
}

// Reads the complete persistent panic record from Hoth.
//
// The record (ec_response_persistent_panic_info) is fetched in
// HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE-byte chunks, using
// PERSISTENT_PANIC_INFO_GET requests with increasing chunk indices.
//
// Returns std::nullopt if the first chunk does not carry the panic magic (see
// matchPersistentPanicMagic). Otherwise returns the fully assembled record.
// Throws ResponseFailure if Hoth returns a non-success result code, or a chunk
// whose length differs from HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE.
std::optional<ec_response_persistent_panic_info>
EcUtilImpl::getHothPersistentPanicInfo() const {
  // The persistent panic info record is 6KiB long, so it has to be retrieved
  // in chunks. The record must be a whole number of chunks; otherwise the
  // integer division below would silently leave its tail value-initialized.
  constexpr size_t chunk_size = HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE;
  static_assert(sizeof(ec_response_persistent_panic_info) % chunk_size == 0,
                "persistent panic info must be a whole number of chunks");
  constexpr size_t num_chunks =
      sizeof(ec_response_persistent_panic_info) / chunk_size;

  // ec_response_persistent_panic_info is 6KiB. Declare the return value this
  // way to leverage NRVO.
  std::optional<ec_response_persistent_panic_info> panic;
  panic.emplace();

  std::span<uint8_t> panic_buf(reinterpret_cast<uint8_t*>(&panic.value()),
                               sizeof(panic.value()));

  auto ptr = panic_buf.begin();
  for (size_t i = 0; i < num_chunks; ++i, ptr += chunk_size) {
    ec_request_persistent_panic_info req = {
        .operation = PERSISTENT_PANIC_INFO_GET,
        // `i` is size_t; make the conversion to the wire field explicit
        // rather than relying on a diagnosed narrowing in a braced init.
        .index =
            static_cast<decltype(ec_request_persistent_panic_info::index)>(i),
    };

    std::vector<uint8_t> response = hostCmd->sendCommand(
        EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PERSISTENT_PANIC_INFO,
        kVersionZero, &req, sizeof(req));
    std::span<const uint8_t> response_body;
    try {
      response_body = getResponseBody(response);
    } catch (const EcException& ec_exception) {
      stdplus::print(stderr,
                     "{} received a bad response "
                     "from Hoth {:#x}\n",
                     __func__, static_cast<uint8_t>(ec_exception.get_result()));
      throw ResponseFailure();
    }
    if (response_body.size() != chunk_size) {
      // stdplus::print takes {}-style placeholders (as used elsewhere in this
      // file); printf-style "%d" would be printed literally and the sizes lost.
      stdplus::print(stderr, "Bad response length {} (expected {})\n",
                     response_body.size(), chunk_size);
      throw ResponseFailure();
    }

    // The first chunk should contain a panic magic in the last 4 bytes in
    // the panic_record; without it there is no record to return.
    if (i == 0 && !matchPersistentPanicMagic(response_body)) {
      panic.reset();
      return panic;
    }

    std::copy(response_body.begin(), response_body.end(), ptr);
  }
  return panic;
}

// Reports whether the first persistent-panic-info chunk carries the panic
// magic.
//
// Fetches only chunk 0 (PERSISTENT_PANIC_INFO_GET, index 0) and checks it with
// matchPersistentPanicMagic. Throws ResponseFailure if Hoth returns a
// non-success result code, or a chunk whose length differs from
// HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE.
bool EcUtilImpl::checkHothPersistentPanicInfo() const {
  ec_request_persistent_panic_info req = {
      .operation = PERSISTENT_PANIC_INFO_GET,
      .index = 0,
  };

  std::vector<uint8_t> response = hostCmd->sendCommand(
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PERSISTENT_PANIC_INFO,
      kVersionZero, &req, sizeof(req));
  std::span<const uint8_t> response_body;
  try {
    response_body = getResponseBody(response);
  } catch (const EcException& ec_exception) {
    stdplus::print(stderr,
                   "{} received a bad response "
                   "from Hoth {:#x}\n",
                   __func__, static_cast<uint8_t>(ec_exception.get_result()));
    throw ResponseFailure();
  }
  const size_t chunk_size = HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE;
  if (response_body.size() != chunk_size) {
    // stdplus::print takes {}-style placeholders (as used elsewhere in this
    // file); printf-style "%d" would be printed literally and the sizes lost.
    stdplus::print(stderr, "Bad response length {} (expected {})\n",
                   response_body.size(), chunk_size);
    throw ResponseFailure();
  }

  return matchPersistentPanicMagic(response_body);
}

// Fetches the authorization record at index 0 from Hoth
// (EC_PRV_CMD_HOTH_GET_AUTHZ_RECORD).
//
// Throws ResponseFailure on any non-success result code. Every code except
// EC_RES_UNAVAILABLE is also logged to stderr; EC_RES_UNAVAILABLE is treated
// as "unsupported" (per the original note, probably a Haven) and kept quiet.
ec_authz_record_get_response EcUtilImpl::getHothAuthRecord() const {
  ec_authz_record_get_request authz_request = {
      .index = 0,
      .reserved = {},
  };
  std::vector<uint8_t> raw_response = hostCmd->sendCommand(
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_AUTHZ_RECORD,
      kVersionZero, &authz_request, sizeof(authz_request));

  try {
    const std::span<const uint8_t> body = getResponseBody(raw_response);
    return stdplus::raw::copyFrom<ec_authz_record_get_response>(body);
  } catch (const EcException& failure) {
    const uint16_t code = failure.get_result();
    // Unsupported, probably a Haven: fail without logging.
    if (code != EC_RES_UNAVAILABLE) {
      stdplus::print(stderr, "{} received a bad response from Hoth {:#x}\n",
                     __func__, static_cast<uint8_t>(code));
    }
    throw ResponseFailure();
  }
}

// Asks Hoth for its key rotation status (KEY_ROTATION_RECORD_GET_STATUS sent
// via EC_PRV_CMD_HOTH_KEY_ROTATION_OP).
//
// Throws ResponseFailure on any non-success result code. Every code except
// EC_RES_INVALID_COMMAND is also logged to stderr.
ec_response_key_rotation_status EcUtilImpl::getHothKeyRotationStatus() const {
  ec_request_key_rotation_record status_request = {
      .operation = KEY_ROTATION_RECORD_GET_STATUS,
      .packet_offset = 0,
      .packet_size = 0,
      .reserved = 0,
  };
  std::vector<uint8_t> raw_response = hostCmd->sendCommand(
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_KEY_ROTATION_OP,
      kVersionZero, &status_request, sizeof(status_request));

  try {
    const std::span<const uint8_t> body = getResponseBody(raw_response);
    return stdplus::raw::copyFrom<ec_response_key_rotation_status>(body);
  } catch (const EcException& failure) {
    const uint16_t code = failure.get_result();
    // TODO: b/434689086 - Do not log EC_RES_INVALID_COMMAND to prevent spam
    // whilst KeyRotation support is being deployed.
    if (code != EC_RES_INVALID_COMMAND) {
      stdplus::print(stderr, "{} received a bad response from Hoth {:#x}\n",
                     __func__, static_cast<uint8_t>(code));
    }
    throw ResponseFailure();
  }
}

}  // namespace internal

}  // namespace hoth

}  // namespace google
