// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "message_util.hpp"

#include <stdplus/raw.hpp>
#include <xyz/openbmc_project/Control/Hoth/error.hpp>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace google::hoth::internal;
using namespace sdbusplus::error::xyz::openbmc_project::control::hoth;

namespace
{

void convertReq(std::vector<uint8_t>& req, const struct ReqHeader& hdr)
{
    req.resize(sizeof(hdr));
    std::memcpy(req.data(), &hdr, req.size());
}

void convertRsp(std::vector<uint8_t>& rsp, const struct RspHeader& hdr)
{
    rsp.resize(sizeof(hdr));
    std::memcpy(rsp.data(), &hdr, rsp.size());
}

} // namespace

TEST(ReqLen, empty)
{
    // With no bytes at all there is no header to inspect, so the request must
    // be rejected.
    std::vector<uint8_t> emptyBuf{};
    EXPECT_THROW(reqLen(emptyBuf), CommandFailure);
}

TEST(ReqLen, wrongStructVersion)
{
    // A header advertising an unsupported struct_version must be rejected.
    ReqHeader header{};
    header.struct_version = 42;

    std::vector<uint8_t> buf;
    convertReq(buf, header);

    EXPECT_THROW(reqLen(buf), CommandFailure);
}

TEST(ReqLen, lenZero)
{
    // With data_len == 0 the reported length is just the header size.
    ReqHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;

    std::vector<uint8_t> buf;
    convertReq(buf, header);

    const size_t expected = sizeof(header);
    EXPECT_EQ(expected, reqLen(buf));
}

TEST(ReqLen, lenZeroTrailing)
{
    // Extra bytes after the header do not change the reported length; only
    // data_len (zero here) is added to the header size.
    ReqHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;

    std::vector<uint8_t> buf;
    convertReq(buf, header);
    buf.insert(buf.end(), 5, 42); // five junk bytes following the header

    EXPECT_EQ(sizeof(header), reqLen(buf));
}

TEST(ReqLen, lenOne)
{
    // The reported length is header + advertised payload, even though the
    // buffer built here holds only the header bytes.
    constexpr uint16_t payloadLen = 1;

    ReqHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertReq(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, reqLen(buf));
}

TEST(ReqLen, lenBig)
{
    // A payload length wider than one byte must be reported in full.
    constexpr uint16_t payloadLen = UINT8_MAX + 9001;

    ReqHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertReq(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, reqLen(buf));
}

TEST(ReqLen, lenMax)
{
    // The largest payload for which header + payload still fits in 16 bits.
    constexpr uint16_t payloadLen = UINT16_MAX - sizeof(ReqHeader);

    ReqHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertReq(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, reqLen(buf));
}

TEST(RspLen, empty)
{
    // With no bytes at all there is no header to inspect, so the response
    // must be rejected.
    std::vector<uint8_t> emptyBuf{};
    EXPECT_THROW(rspLen(emptyBuf), ResponseFailure);
}

TEST(RspLen, wrongStructVersion)
{
    // A header advertising an unsupported struct_version must be rejected.
    RspHeader header{};
    header.struct_version = 42;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);

    EXPECT_THROW(rspLen(buf), ResponseFailure);
}

TEST(RspLen, lenZero)
{
    // With data_len == 0 the reported length is just the header size.
    RspHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);

    const size_t expected = sizeof(header);
    EXPECT_EQ(expected, rspLen(buf));
}

TEST(RspLen, lenZeroTrailing)
{
    // Extra bytes after the header do not change the reported length; only
    // data_len (zero here) is added to the header size.
    RspHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);
    buf.insert(buf.end(), 5, 42); // five junk bytes following the header

    EXPECT_EQ(sizeof(header), rspLen(buf));
}

TEST(RspLen, lenOne)
{
    // The reported length is header + advertised payload, even though the
    // buffer built here holds only the header bytes.
    constexpr uint16_t payloadLen = 1;

    RspHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, rspLen(buf));
}

TEST(RspLen, lenBig)
{
    // A payload length wider than one byte must be reported in full.
    constexpr uint16_t payloadLen = UINT8_MAX + 9001;

    RspHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, rspLen(buf));
}

TEST(RspLen, lenMax)
{
    // The largest payload for which header + payload still fits in 16 bits.
    constexpr uint16_t payloadLen = UINT16_MAX - sizeof(RspHeader);

    RspHeader header{};
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.data_len = payloadLen;

    std::vector<uint8_t> buf;
    convertRsp(buf, header);

    EXPECT_EQ(sizeof(header) + payloadLen, rspLen(buf));
}

TEST(CalculateChecksum, calculateChecksumSuccess)
{
    // Tests that the checksum is the byte-wise sum of header and body,
    // truncated to 8 bits (i.e. taken modulo 256).

    uint8_t header[] = {1, 3, 5, 7};
    uint8_t body[] = {2, 4, 6, 8};

    // (1 + 3 + 5 + 7) + (2 + 4 + 6 + 8) = 36
    EXPECT_EQ(36, calculateChecksum(header, body));

    // Boundary: a sum of exactly 256 is the first value that no longer fits
    // in a byte, so it must wrap around to 0.
    body[0] = 200;
    body[1] = 20;
    body[2] = 10;
    body[3] = 10;

    // (1 + 3 + 5 + 7) + (200 + 20 + 10 + 10) = 256
    // 256 % 256 = 0
    EXPECT_EQ(0, calculateChecksum(header, body));

    // Now test a sum that exceeds 255 by a wide margin.
    body[0] = 200;
    body[1] = 202;
    body[2] = 6;
    body[3] = 8;

    // (1 + 3 + 5 + 7) + (200 + 202 + 6 + 8) = 432
    // 432 % 256 = 176
    EXPECT_EQ(176, calculateChecksum(header, body));
}

TEST(PopulateReqHeader, populateReqHeaderSuccess)
{
    // A null output header must be rejected (a null payload with length 0 is
    // accepted by PopulateRequestHeaderWithNullRequestData, so the throw here
    // comes from the null header).
    ASSERT_THROW(populateReqHeader(0, 0, nullptr, 0, nullptr), CommandFailure);

    // Truncating std::rand() (rather than `% UINT16_MAX` / `% UINT8_MAX`)
    // also lets the maximum field values 0xffff / 0xff be generated, given
    // RAND_MAX >= 0xffff (true for glibc).
    // NOLINTNEXTLINE(cert-msc30-c, cert-msc50-cpp)
    const auto command = static_cast<uint16_t>(std::rand());
    // NOLINTNEXTLINE(cert-msc30-c, cert-msc50-cpp)
    const auto command_version = static_cast<uint8_t>(std::rand());

    constexpr size_t kRequestSize = 15;

    // Use a lambda instead of passing `std::rand` by address: taking the
    // address of a standard library function is unspecified, and the lambda
    // makes the int -> uint8_t truncation explicit.
    auto randomByte = [] {
        // NOLINTNEXTLINE(cert-msc30-c, cert-msc50-cpp)
        return static_cast<uint8_t>(std::rand());
    };
    std::vector<uint8_t> request(kRequestSize);
    std::generate(request.begin(), request.end(), randomByte);

    // Pre-fill with a non-zero pattern so that every checked field is proven
    // to be written by populateReqHeader (a zero-filled header would let the
    // `reserved == 0` check pass trivially) and no indeterminate value is
    // ever read.
    ReqHeader request_header;
    std::memset(&request_header, 0xff, sizeof(request_header));

    ASSERT_NO_THROW(populateReqHeader(command, command_version, request.data(),
                                      request.size(), &request_header));

    EXPECT_EQ(request_header.struct_version, SUPPORTED_STRUCT_VERSION);
    EXPECT_EQ(request_header.command, command);
    EXPECT_EQ(request_header.command_version, command_version);
    EXPECT_EQ(request_header.reserved, 0);
    EXPECT_EQ(request_header.data_len, request.size());

    // The header's checksum field must make header + payload sum to zero.
    EXPECT_EQ(0, calculateChecksum(
                     stdplus::raw::asSpan<uint8_t>(request_header), request));
}

TEST(PopulateReqHeader, PopulateRequestHeaderWithNullRequestData)
{
    // A null payload with zero length is accepted; the header alone must
    // describe an empty request and carry a checksum that sums it to zero.

    // Non-zero pre-fill: proves populateReqHeader writes each checked field
    // and avoids reading indeterminate values.
    ReqHeader request_header;
    std::memset(&request_header, 0xff, sizeof(request_header));

    ASSERT_NO_THROW(populateReqHeader(0xa5, 12, nullptr, 0, &request_header));

    EXPECT_EQ(request_header.struct_version, SUPPORTED_STRUCT_VERSION);
    EXPECT_EQ(request_header.command, 0xa5);
    EXPECT_EQ(request_header.command_version, 12);
    EXPECT_EQ(request_header.reserved, 0);
    EXPECT_EQ(request_header.data_len, 0);
    EXPECT_EQ(0,
              calculateChecksum(stdplus::raw::asSpan<uint8_t>(request_header)));
}
