// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ec_util.hpp"

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <xyz/openbmc_project/Control/Hoth/error.hpp>

#include "google3/host_commands.h"
#include "host_command_mock.hpp"
#include "payload_update.hpp"

// NOLINTNEXTLINE(google-build-using-namespace)
using namespace std::literals;
using sdbusplus::error::xyz::openbmc_project::control::hoth::ResponseFailure;

using ::testing::_;
using ::testing::Return;

namespace google {
namespace hoth {
namespace internal {
namespace {

// Eight raw bytes that the tests use as a well-formed response header:
// 0x03, 0xfd, then six zero bytes (byte 2 is the RspHeader.result field the
// tests overwrite to simulate failures). The explicit length keeps the
// embedded NUL bytes in the string.
const std::string goodResponseStr("\x03\xfd\x00\x00\x00\x00\x00\x00", 8);

// Common fixture for every suite in this file: it owns a mocked host command
// interface and an EcUtilImpl that is handed a pointer to that mock, so each
// test can script the replies EcUtilImpl receives.
class EcUtilTest : public ::testing::Test {
 protected:
  // Mocked host command interface. Must stay declared ahead of ecUtil, which
  // is initialized with this member's address.
  internal::HostCommandMock hostCmd;

  // Object under test, wired to the mock above.
  internal::EcUtilImpl ecUtil{&hostCmd};
};

// Matches a pointer argument whose first byte equals `req`. The argument is
// read through a `const uint8_t*`, so only that leading byte is inspected.
MATCHER_P(dummyMatches, req, "") {
  const uint8_t first_byte = *static_cast<const uint8_t*>(arg);
  return first_byte == req;
}

// Suite for EcUtilImpl::getHothStatistics(); adds nothing to EcUtilTest and
// exists only to group these tests under their own suite name.
class EcUtilStatisticTest : public EcUtilTest {};

// getHothStatistics() must throw CommandRunException when the response header
// reports a result other than EC_RES_SUCCESS.
TEST_F(EcUtilStatisticTest, sendCommandReturnsBadResultFails) {
  // Begin with the good header and overwrite RspHeader.result (byte 2) with
  // an error code.
  std::vector<uint8_t> error_rsp(goodResponseStr.begin(),
                                 goodResponseStr.end());
  error_rsp[2] = internal::EC_RES_ERROR;

  const auto statistics_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_STATISTICS;
  EXPECT_CALL(hostCmd,
              sendCommand(statistics_cmd, ecUtil.kVersionZero, nullptr, 0))
      .WillOnce(Return(error_rsp));

  EXPECT_THROW(ecUtil.getHothStatistics(), CommandRunException);
}

// getHothStatistics() must not throw when the mock returns the good header
// followed by 256 zero bytes of payload.
TEST_F(EcUtilStatisticTest, sendCommandReturnsGoodResponseSuccess) {
  std::vector<uint8_t> stats_rsp(goodResponseStr.begin(),
                                 goodResponseStr.end());
  // Append 256 zero bytes after the header in one step.
  stats_rsp.resize(stats_rsp.size() + 256, 0);

  const auto statistics_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_STATISTICS;
  EXPECT_CALL(hostCmd,
              sendCommand(statistics_cmd, ecUtil.kVersionZero, nullptr, 0))
      .WillOnce(Return(stats_rsp));

  EXPECT_NO_THROW(ecUtil.getHothStatistics());
}

// Suite for checkHothPersistentPanicInfo() / getHothPersistentPanicInfo();
// adds nothing to EcUtilTest beyond a distinct suite name.
class EcUtilPersistentPanicTest : public EcUtilTest {};

// Image of one persistent panic info reply as used by these tests: a response
// header followed by a chunk-sized payload. The tests convert between this
// struct and raw response bytes, so member order and layout matter.
struct panic_host_command_response {
  // Response header (struct_version, checksum, result, data_len, reserved).
  RspHeader hdr;
  // Payload bytes for a single chunk.
  uint8_t body[HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE];
};

// Baseline persistent panic response that the tests copy and then perturb.
// Its header carries data_len equal to one full chunk and a matching checksum
// (presumably chosen so that all response bytes sum to zero modulo 256 --
// confirm against the host command checksum rules). The body holds 140 zero
// bytes followed by the panic magic; any body bytes past index 143 are
// implicitly zero-initialized.
const panic_host_command_response kPanicResponseTemplate = {
    .hdr =
        {
            .struct_version = 3,
            .checksum = 0xb9,
            .result = 0,
            .data_len = HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE,
            .reserved = 0,
        },
    .body =
        {
            // Bytes 0-139: zero filler (7 rows of 20).
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            // Bytes 140-143: panic magic, ASCII "Pnc!".
            0x50, 0x6e, 0x63, 0x21,
        },
};

// A reply that is one byte shorter than a full chunk (with data_len lowered
// to match) must make both checkHothPersistentPanicInfo() and
// getHothPersistentPanicInfo() throw ResponseFailure.
TEST_F(EcUtilPersistentPanicTest, incorrectResponseSizeThrows) {
  // Edit a typed copy of the template and serialize it afterwards, rather
  // than writing through a struct pointer aliased onto the byte buffer.
  panic_host_command_response truncated = kPanicResponseTemplate;
  truncated.hdr.data_len -= 1;
  // Checksum tweak kept exactly as before.
  // NOTE(review): confirm this still balances the shortened reply if
  // EcUtilImpl ever validates checksums itself.
  truncated.hdr.checksum += 1;

  // Serialize everything except the final body byte.
  const auto* raw = reinterpret_cast<const uint8_t*>(&truncated);
  const std::vector<uint8_t> rsp_buf(raw, raw + sizeof(truncated) - 1);

  // One call is expected from each of the two entry points below.
  EXPECT_CALL(hostCmd, sendCommand(EC_CMD_BOARD_SPECIFIC_BASE +
                                       EC_PRV_CMD_HOTH_PERSISTENT_PANIC_INFO,
                                   ecUtil.kVersionZero, _,
                                   sizeof(ec_request_persistent_panic_info)))
      .Times(2)
      .WillRepeatedly(Return(rsp_buf));
  EXPECT_THROW(ecUtil.checkHothPersistentPanicInfo(), ResponseFailure);
  EXPECT_THROW(ecUtil.getHothPersistentPanicInfo(), ResponseFailure);
}

// A full-sized reply whose panic magic is corrupted must make both
// checkHothPersistentPanicInfo() and getHothPersistentPanicInfo() evaluate to
// false instead of throwing.
TEST_F(EcUtilPersistentPanicTest, incorrectPanicMagicReturnsNullopt) {
  // Edit a typed copy of the template and serialize it afterwards, rather
  // than writing through a struct pointer aliased onto the byte buffer.
  panic_host_command_response corrupted = kPanicResponseTemplate;
  // Index 143 is the final byte of the 4-byte magic in the template (body
  // indices 140-143); it is deliberately not the last byte of the chunk.
  corrupted.body[143] -= 1;
  // Offset the checksum by the same amount so the byte sum is unchanged.
  corrupted.hdr.checksum += 1;

  const auto* raw = reinterpret_cast<const uint8_t*>(&corrupted);
  const std::vector<uint8_t> rsp_buf(raw, raw + sizeof(corrupted));

  // One call is expected from each of the two entry points below.
  EXPECT_CALL(hostCmd, sendCommand(EC_CMD_BOARD_SPECIFIC_BASE +
                                       EC_PRV_CMD_HOTH_PERSISTENT_PANIC_INFO,
                                   ecUtil.kVersionZero, _,
                                   sizeof(ec_request_persistent_panic_info)))
      .Times(2)
      .WillRepeatedly(Return(rsp_buf));
  EXPECT_FALSE(ecUtil.checkHothPersistentPanicInfo());
  EXPECT_FALSE(ecUtil.getHothPersistentPanicInfo());
}

// With well-formed replies, checkHothPersistentPanicInfo() must report true
// and getHothPersistentPanicInfo() must return a record in which each of the
// 12 chunks carries the bytes served for it, in order.
TEST_F(EcUtilPersistentPanicTest, correctHostCommandReturnsFullPanicRecord) {
  constexpr uint8_t kNumChunks = 12;
  constexpr size_t kChunkLastByte = HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE - 1;

  // Build one reply per chunk from a typed copy of the template, then
  // serialize it; this avoids writing through a struct pointer aliased onto
  // the byte buffer.
  std::vector<std::vector<uint8_t>> rsp_bufs;
  rsp_bufs.reserve(kNumChunks);
  for (uint8_t i = 0; i < kNumChunks; ++i) {
    panic_host_command_response chunk = kPanicResponseTemplate;
    // Tweak the first and last byte of the chunk so every chunk is
    // distinguishable: the first counts up, the last counts down (wrapping).
    chunk.body[0] = static_cast<uint8_t>(chunk.body[0] + i);
    chunk.body[kChunkLastByte] =
        static_cast<uint8_t>(chunk.body[kChunkLastByte] - i);
    const auto* raw = reinterpret_cast<const uint8_t*>(&chunk);
    rsp_bufs.emplace_back(raw, raw + sizeof(chunk));
  }

  // 13 calls in total: the first reply is consumed by the check call below
  // (presumably a read of chunk 0 -- confirm against EcUtilImpl), followed by
  // one reply per chunk for the full read.
  EXPECT_CALL(hostCmd, sendCommand(EC_CMD_BOARD_SPECIFIC_BASE +
                                       EC_PRV_CMD_HOTH_PERSISTENT_PANIC_INFO,
                                   ecUtil.kVersionZero, _,
                                   sizeof(ec_request_persistent_panic_info)))
      .Times(kNumChunks + 1)
      .WillOnce(Return(rsp_bufs[0]))
      .WillOnce(Return(rsp_bufs[0]))
      .WillOnce(Return(rsp_bufs[1]))
      .WillOnce(Return(rsp_bufs[2]))
      .WillOnce(Return(rsp_bufs[3]))
      .WillOnce(Return(rsp_bufs[4]))
      .WillOnce(Return(rsp_bufs[5]))
      .WillOnce(Return(rsp_bufs[6]))
      .WillOnce(Return(rsp_bufs[7]))
      .WillOnce(Return(rsp_bufs[8]))
      .WillOnce(Return(rsp_bufs[9]))
      .WillOnce(Return(rsp_bufs[10]))
      .WillOnce(Return(rsp_bufs[11]));

  EXPECT_TRUE(ecUtil.checkHothPersistentPanicInfo());

  auto panic = ecUtil.getHothPersistentPanicInfo();
  // Fatal check: dereferencing an empty result below would otherwise abort
  // the test with an exception instead of a readable failure.
  ASSERT_TRUE(panic.has_value()) << "no persistent panic record returned";

  std::span<const uint8_t> panic_buf(
      reinterpret_cast<const uint8_t*>(&*panic), sizeof(*panic));
  // Guard the indexing below against a record smaller than 12 chunks.
  ASSERT_GE(panic_buf.size(), static_cast<size_t>(kNumChunks) *
                                  HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE);

  // The expectations take the template's bytes at both probed positions to
  // be zero, so chunk i must start with i and end with (uint8_t)(-i).
  for (uint8_t i = 0; i < kNumChunks; ++i) {
    const size_t chunk_start =
        static_cast<size_t>(i) * HOTH_PERSISTENT_PANIC_INFO_CHUNK_SIZE;
    EXPECT_EQ(panic_buf[chunk_start], i);
    EXPECT_EQ(panic_buf[chunk_start + kChunkLastByte],
              static_cast<uint8_t>(-i));
  }
}

// Suite for EcUtilImpl::getHothAuthRecord(); adds nothing to EcUtilTest beyond
// a distinct suite name.
class EcUtilAuthRecordTest : public EcUtilTest {};

// getHothAuthRecord() must throw CommandNotSupportedException when the
// "is host command supported" query is answered with a payload byte of 0.
TEST_F(EcUtilAuthRecordTest, authRecordNotSupported) {
  // Good header plus a single payload byte of 0.
  std::vector<uint8_t> unsupported_rsp(goodResponseStr.begin(),
                                       goodResponseStr.end());
  unsupported_rsp.push_back(0);

  const auto is_supported_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_IS_HOST_COMMAND_SUPPORTED;
  EXPECT_CALL(hostCmd, sendCommand(is_supported_cmd, _, _, _))
      .WillOnce(Return(unsupported_rsp));

  EXPECT_THROW(ecUtil.getHothAuthRecord(), CommandNotSupportedException);
}

// Suite for EcUtilImpl::getHothKeyRotationStatus(); adds nothing to EcUtilTest
// beyond a distinct suite name.
class EcUtilKeyRotationTest : public EcUtilTest {};

// getHothKeyRotationStatus() must throw CommandNotSupportedException when the
// "is host command supported" query is answered with a payload byte of 0.
// NOTE(review): the test name says "authRecord" although it exercises the key
// rotation status call; it looks copy-pasted and is kept only so the test ID
// stays stable.
TEST_F(EcUtilKeyRotationTest, authRecordNotSupported) {
  // Good header plus a single payload byte of 0.
  std::vector<uint8_t> unsupported_rsp(goodResponseStr.begin(),
                                       goodResponseStr.end());
  unsupported_rsp.push_back(0);

  const auto is_supported_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_IS_HOST_COMMAND_SUPPORTED;
  EXPECT_CALL(hostCmd, sendCommand(is_supported_cmd, _, _, _))
      .WillOnce(Return(unsupported_rsp));

  EXPECT_THROW(ecUtil.getHothKeyRotationStatus(),
               CommandNotSupportedException);
}

// Suite for EcUtilImpl::getSecureBootEnforcementState(); adds nothing to
// EcUtilTest beyond a distinct suite name.
class EcUtilSecureBootEnforcementTest : public EcUtilTest {};

// getSecureBootEnforcementState() must throw CommandNotSupportedException when
// the "is host command supported" query is answered with a payload byte of 0.
TEST_F(EcUtilSecureBootEnforcementTest, commandNotSupported) {
  // Good header plus a single payload byte of 0.
  std::vector<uint8_t> unsupported_rsp(goodResponseStr.begin(),
                                       goodResponseStr.end());
  unsupported_rsp.push_back(0);

  const auto is_supported_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_IS_HOST_COMMAND_SUPPORTED;
  EXPECT_CALL(hostCmd, sendCommand(is_supported_cmd, _, _, _))
      .WillOnce(Return(unsupported_rsp));

  EXPECT_THROW(ecUtil.getSecureBootEnforcementState(),
               CommandNotSupportedException);
}

// When the support query succeeds (payload byte 1) but the actual secure boot
// enforcement command reports an error result, the call must throw
// CommandRunException.
TEST_F(EcUtilSecureBootEnforcementTest, sendCommandReturnsBadResultFails) {
  // Reply to the support query: good header plus a payload byte of 1.
  std::vector<uint8_t> supported_rsp(goodResponseStr.begin(),
                                     goodResponseStr.end());
  supported_rsp.push_back(1);
  const auto is_supported_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_IS_HOST_COMMAND_SUPPORTED;
  EXPECT_CALL(hostCmd, sendCommand(is_supported_cmd, _, _, _))
      .WillOnce(Return(supported_rsp));

  // Reply to the enforcement query: good header with RspHeader.result
  // (byte 2) overwritten by an error code.
  std::vector<uint8_t> error_rsp(goodResponseStr.begin(),
                                 goodResponseStr.end());
  error_rsp[2] = internal::EC_RES_ERROR;
  const auto enforcement_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_SECURE_BOOT_ENFORCEMENT;
  EXPECT_CALL(hostCmd,
              sendCommand(enforcement_cmd, ecUtil.kVersionZero, nullptr, 0))
      .WillOnce(Return(error_rsp));

  EXPECT_THROW(ecUtil.getSecureBootEnforcementState(), CommandRunException);
}

// When the support query succeeds and the enforcement command returns a good
// header followed by a serialized secure_boot_enforcement_state, the returned
// state must carry the same `enabled` value.
TEST_F(EcUtilSecureBootEnforcementTest, sendCommandReturnsGoodResponseSuccess) {
  // Reply to the support query: good header plus a payload byte of 1.
  std::vector<uint8_t> supported_rsp(goodResponseStr.begin(),
                                     goodResponseStr.end());
  supported_rsp.push_back(1);
  const auto is_supported_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_IS_HOST_COMMAND_SUPPORTED;
  EXPECT_CALL(hostCmd, sendCommand(is_supported_cmd, _, _, _))
      .WillOnce(Return(supported_rsp));

  // Reply to the enforcement query: good header followed by the raw bytes of
  // the expected state.
  secure_boot_enforcement_state expected = {.enabled = 1, .reserved0 = {}};
  const auto* state_bytes = reinterpret_cast<const uint8_t*>(&expected);
  std::vector<uint8_t> state_rsp(goodResponseStr.begin(),
                                 goodResponseStr.end());
  state_rsp.insert(state_rsp.end(), state_bytes,
                   state_bytes + sizeof(expected));
  const auto enforcement_cmd =
      EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_GET_SECURE_BOOT_ENFORCEMENT;
  EXPECT_CALL(hostCmd,
              sendCommand(enforcement_cmd, ecUtil.kVersionZero, nullptr, 0))
      .WillOnce(Return(state_rsp));

  secure_boot_enforcement_state reported =
      ecUtil.getSecureBootEnforcementState();
  EXPECT_EQ(reported.enabled, expected.enabled);
}

}  // namespace
}  // namespace internal
}  // namespace hoth
}  // namespace google
