// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "google3/host_commands.h"

#include "host_command.hpp"
#include "message_mock.hpp"
#include "message_util.hpp"
#include "payload_update.hpp"

#include <boost/asio.hpp>
#include <stdplus/raw.hpp>
#include <xyz/openbmc_project/Common/error.hpp>
#include <xyz/openbmc_project/Control/Hoth/error.hpp>

#include <chrono>
#include <future>
#include <string>
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

using namespace std::literals;
using namespace sdbusplus::error::xyz::openbmc_project::control::hoth;
using namespace sdbusplus::error::xyz::openbmc_project::common;
using namespace google::hoth;
using namespace google::hoth::internal;

using ::testing::_;
using ::testing::ContainerEq;
using ::testing::DoAll;
using ::testing::Return;
using ::testing::Throw;

namespace
{

// Endpoint name handed to every HostCommandImpl constructed by these tests.
const std::string kEndpointTestName{"test"};
// Wait time (seconds, per the name) passed to LogCollectorUtil in these tests.
constexpr int kTestAsyncWaitTimeInSeconds{1};
// Test double for HostCommandImpl.
//
// Both sendCommand overloads (request only, and request + timeout) are
// replaced by gmock mocks so tests can script the device's replies. The
// callParent* / stopParent* helpers forward to the real HostCommandImpl
// log-collector entry points so tests can drive them directly.
class HostCommandImplMock : public HostCommandImpl
{
  public:
    HostCommandImplMock(MessageIntf* message, boost::asio::io_context* ioCtx,
                        LogCollectorUtil* collectorUtil,
                        const std::string& endpointName, size_t max_retries = 0,
                        bool allowLegacyVerify = true,
                        const uint32_t uartChannelId = 0) :
        HostCommandImpl(message, ioCtx, collectorUtil, endpointName,
                        max_retries, allowLegacyVerify, uartChannelId)
    {}

    MOCK_METHOD1(sendCommand,
                 std::vector<uint8_t>(const std::vector<uint8_t>&));
    MOCK_METHOD2(sendCommand, std::vector<uint8_t>(const std::vector<uint8_t>&,
                                                   std::chrono::milliseconds));

    // Starts the parent's asynchronous Hoth log collection and hands back the
    // request id it returns (used with waitOnLogCollectorPromise()).
    int callParentAsynHothCollectorFuture(bool cleanupPromiseAfterExecution)
    {
        const int requestId = HostCommandImpl::collectHothLogsAsync(
            cleanupPromiseAfterExecution);
        return requestId;
    }

    // Starts the parent's asynchronous UART log collection.
    void callParentScheduler()
    {
        HostCommandImpl::collectUartLogsAsync();
    }

    // Stops the parent's UART log collection.
    void stopParentScheduler()
    {
        HostCommandImpl::stopUartLogs();
    }
};

TEST(HothLogCollector, TriggerLogCollector)
{
    // Kicks off one asynchronous Hoth log collection with sendCommand mocked
    // out, drains the io_context, then waits on the collector's promise.
    // Exactly two sendCommand calls are expected.
    testing::StrictMock<internal::MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);

    const std::vector<uint8_t> cannedResponse{0x03, 0xFD, 0x00, 0x00,
                                              0x00, 0x00, 0x00, 0x00};
    EXPECT_CALL(hostCommandMock, sendCommand(_))
        .Times(2)
        .WillRepeatedly(Return(cannedResponse));

    const int requestId = hostCommandMock.callParentAsynHothCollectorFuture(
        /*cleanupPromiseAfterExecution=*/false);
    io.run();
    hostCommandMock.waitOnLogCollectorPromise(requestId);
    io.stop();
}

TEST(UartLogCollector, TriggerLogCollector)
{
    // Drives one step of the UART log collection scheduler with sendCommand
    // mocked out: the first call returns the channel write offset response,
    // every later call returns a buffer of ASCII log text.
    testing::StrictMock<internal::MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    // Use the shared test constants so every log-collector test is configured
    // the same way.
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);
    std::vector<uint8_t> writechannelOffsetResponse = {
        0x03, 0x37, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x44, 0x11, 0x1b};
    std::vector<uint8_t> response = {
        0x03, 0x2a, 0x00, 0x00, 0xf8, 0x03, 0x00, 0x00, 0xda, 0x84, 0x30, 0x1b,
        0x78, 0x3a, 0x38, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36,
        0x42, 0x45, 0x43, 0x34, 0x41, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20,
        0x45, 0x52, 0x52, 0x20, 0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35,
        0x20, 0x0a, 0x5b, 0x4d, 0x33, 0x31, 0x2d, 0x36, 0x36, 0x36, 0x42, 0x46,
        0x45, 0x34, 0x37, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52,
        0x52, 0x20, 0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a,
        0x5b, 0x4d, 0x33, 0x31, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x30,
        0x46, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52, 0x52, 0x20,
        0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a, 0x5b, 0x4d,
        0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x31, 0x37, 0x2d,
        0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52, 0x52, 0x20, 0x53, 0x4d,
        0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a, 0x5b, 0x4d, 0x33, 0x30,
        0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x33, 0x30, 0x2d, 0x4e, 0x5d,
        0x20, 0x67, 0x73, 0x74, 0x5f, 0x69, 0x6e, 0x6a, 0x5f, 0x73, 0x71, 0x65,
        0x5f, 0x70, 0x72, 0x6f, 0x63, 0x5f, 0x63, 0x74, 0x78, 0x2e, 0x74, 0x61,
        0x69, 0x6c, 0x3a, 0x31, 0x34, 0x2c, 0x20, 0x68, 0x65, 0x61, 0x64, 0x3a,
        0x31, 0x34, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43,
        0x30, 0x36, 0x34, 0x32, 0x2d, 0x41, 0x5d, 0x20, 0x4d, 0x61, 0x73, 0x74,
        0x65, 0x72, 0x20, 0x41, 0x52, 0x42, 0x4c, 0x4f, 0x53, 0x54, 0x20, 0x3a,
        0x35, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30,
        0x36, 0x34, 0x32, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52,
        0x52, 0x3a, 0x20, 0x49, 0x4e, 0x43, 0x20, 0x62, 0x61, 0x63, 0x6b, 0x3a,
        0x34, 0x30, 0x2c, 0x20, 0x6e, 0x4e, 0x3a, 0x20, 0x32, 0x0a, 0};
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .WillOnce(Return(writechannelOffsetResponse))
        .WillRepeatedly(Return(response));
    hostCommandMock.callParentScheduler();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}

TEST(UartLogCollector, GetChannelWriteOffsetThrowsException)
{
    testing::StrictMock<internal::MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);
    // Every attempt to read the channel write offset fails. With four
    // WillOnce() clauses and no explicit Times(), gmock expects exactly four
    // sendCommand calls: the initial attempt plus three retries, each of which
    // throws InterfaceError. The test checks that these failures are handled
    // by the collector and do not escape io.run_one() (no crash).
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .WillOnce(Throw(InterfaceError()))  // Initial failure
        .WillOnce(Throw(InterfaceError()))  // Retry 1
        .WillOnce(Throw(InterfaceError()))  // Retry 2
        .WillOnce(Throw(InterfaceError())); // Retry 3
    hostCommandMock.callParentScheduler();
    io.run_one();
    io.run_one();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}

TEST(UartLogCollector, LogBufferRetrievalThrowsException)
{
    testing::StrictMock<internal::MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    // Use the shared test constants so every log-collector test is configured
    // the same way.
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);

    std::vector<uint8_t> writechannelOffsetResponse = {
        0x03, 0x37, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x44, 0x11, 0x1b};
    // The first sendCommand, which reads the channel write offset, succeeds.
    // The next one, which fetches the first log buffer, throws InterfaceError.
    // The test checks that this exception is handled without a crash.
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .WillOnce(Return(writechannelOffsetResponse))
        .WillOnce(Throw(InterfaceError()));
    hostCommandMock.callParentScheduler();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}

// Builds a PAYLOAD_UPDATE host-command request buffer: a ReqHeader followed by
// a payload_update_packet.
//
// @param type    payload_update_packet::type (e.g. PAYLOAD_UPDATE_INITIATE).
// @param offset  payload_update_packet::offset.
// @param len     payload_update_packet::len.
// @return        the raw request bytes. Only struct_version and command are
//                set in the header; the packet gets type/offset/len.
//
// The array is value-initialised (`{}`), so every field not assigned below,
// such as the remaining header fields and any padding, is zero rather than
// indeterminate. isCommandLongRunning() reads this buffer, so leaving bytes
// uninitialised would make the tests depend on stack garbage.
auto fillPayloadReq(uint8_t type, uint32_t offset = 0, uint32_t len = 0)
{
    std::array<uint8_t, sizeof(ReqHeader) + sizeof(payload_update_packet)>
        ret{};
    auto req_view = stdplus::raw::asSpan<uint8_t>(ret);
    auto& hdr = stdplus::raw::extractRef<ReqHeader>(req_view);
    hdr.struct_version = SUPPORTED_STRUCT_VERSION;
    hdr.command = EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PAYLOAD_UPDATE;
    auto& req = stdplus::raw::extractRef<payload_update_packet>(req_view);
    req.type = type;
    req.offset = offset;
    req.len = len;
    return ret;
}

TEST(IsCommandLongRunningTest, ShortRequest)
{
    // A PAYLOAD_UPDATE_INITIATE request truncated by one byte is reported as
    // EC_RES_SUCCESS, unlike the full-length request in the Initiate test.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    auto truncated =
        std::span<uint8_t>(request).first(request.size() - 1);
    EXPECT_EQ(EC_RES_SUCCESS,
              HostCommandImpl::isCommandLongRunning(truncated, true));
}

TEST(IsCommandLongRunningTest, NonPayload)
{
    // Same packet as the Initiate test, but with the header command cleared
    // to 0. The request is then reported as EC_RES_SUCCESS.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    stdplus::raw::refFrom<ReqHeader>(request).command = 0;
    EXPECT_EQ(EC_RES_SUCCESS,
              HostCommandImpl::isCommandLongRunning(request, true));
}

TEST(IsCommandLongRunningTest, Initiate)
{
    // PAYLOAD_UPDATE_INITIATE is always refused.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, verdict);
}

TEST(IsCommandLongRunningTest, Finalize)
{
    // PAYLOAD_UPDATE_FINALIZE is always refused.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_FINALIZE);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, verdict);
}

TEST(IsCommandLongRunningTest, VerifyAllowed)
{
    // PAYLOAD_UPDATE_VERIFY passes when the second argument is true
    // (contrast with VerifyBanned).
    auto request = fillPayloadReq(PAYLOAD_UPDATE_VERIFY);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_SUCCESS, verdict);
}

TEST(IsCommandLongRunningTest, VerifyBanned)
{
    // PAYLOAD_UPDATE_VERIFY is refused when the second argument is false
    // (contrast with VerifyAllowed).
    auto request = fillPayloadReq(PAYLOAD_UPDATE_VERIFY);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, false);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, verdict);
}

TEST(IsCommandLongRunningTest, ShortErase)
{
    // A 10-byte erase at offset 16 MiB is allowed.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_ERASE, /*offset=*/1 << 24,
                                  /*len=*/10);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_SUCCESS, verdict);
}

TEST(IsCommandLongRunningTest, LongErase)
{
    // An erase of the maximum possible length from offset 0 is refused.
    constexpr uint32_t kHugeLength = std::numeric_limits<uint32_t>::max();
    auto request = fillPayloadReq(PAYLOAD_UPDATE_ERASE, /*offset=*/0,
                                  kHugeLength);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, verdict);
}

TEST(IsCommandLongRunningTest, ShortVerifyChunk)
{
    // PAYLOAD_UPDATE_VERIFY_CHUNK yields EC_RES_INVALID_PARAM even for a
    // small range.
    auto request = fillPayloadReq(PAYLOAD_UPDATE_VERIFY_CHUNK,
                                  /*offset=*/1 << 24, /*len=*/10);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_INVALID_PARAM, verdict);
}

TEST(IsCommandLongRunningTest, LongVerifyChunk)
{
    // PAYLOAD_UPDATE_VERIFY_CHUNK yields EC_RES_INVALID_PARAM for a
    // maximum-length range as well.
    constexpr uint32_t kHugeLength = std::numeric_limits<uint32_t>::max();
    auto request = fillPayloadReq(PAYLOAD_UPDATE_VERIFY_CHUNK, /*offset=*/0,
                                  kHugeLength);
    const auto verdict = HostCommandImpl::isCommandLongRunning(request, true);
    EXPECT_EQ(EC_RES_INVALID_PARAM, verdict);
}

// Canned "hello" request/response frames shared by the sync tests. Byte 0 is
// struct_version and byte 1 is the checksum (see the struct-version tests).
// The bytes of each frame sum to 0 mod 256. The std::string forms use the
// `s` literal so the embedded NUL bytes are kept.
static const std::string hello_req_str =
    "\x03\xb6\x01\x00\x00\x00\x04\x00\x42\x00\x00\x00"s;
static const std::string hello_rsp_str =
    "\x03\xad\x00\x00\x04\x00\x00\x00\x46\x03\x02\x01"s;

// Byte-vector views of the frames above, in the form sendCommand() consumes.
static const std::vector<uint8_t> hello_req(hello_req_str.cbegin(),
                                            hello_req_str.cend());
static const std::vector<uint8_t> hello_rsp(hello_rsp_str.cbegin(),
                                            hello_rsp_str.cend());

// Thrown by the BufCopy action when the requested window extends past the end
// of its canned buffer.
class BufError : public std::exception
{
  public:
    const char* what() const noexcept override
    {
        static constexpr const char* kMessage = "Not enough buffer";
        return kMessage;
    }
};

// gmock action for recv-style calls: copies arg1 bytes of `tgt`, starting at
// offset arg2, into the destination buffer arg0. Throws BufError when that
// window runs past the end of `tgt`.
ACTION_P(BufCopy, tgt)
{
    const auto windowEnd = arg2 + arg1;
    if (windowEnd > tgt.size())
    {
        throw BufError();
    }
    const auto* source = tgt.data() + arg2;
    memcpy(arg0, source, arg1);
}

// gmock action for send-style calls: asserts that the arg1 bytes at arg0 equal
// `tgt` starting at offset arg2. Fails fatally if the window runs past the end
// of `tgt`.
ACTION_P(BufMatch, tgt)
{
    const auto windowEnd = arg2 + arg1;
    ASSERT_LE(windowEnd, tgt.size());
    const auto* expected = tgt.data() + arg2;
    EXPECT_EQ(0, memcmp(arg0, expected, arg1));
}

boost::asio::io_context io; // Global test io context for async timer

// Fixture that wires a real HostCommandImpl to a MessageMock, so each test can
// script the mailbox send/recv traffic.
class HostCommandTest : public ::testing::Test
{
  protected:
    // All dependencies are members built in the initialiser list. The original
    // constructor body also declared local RateLimiter/LogCollectorUtil objects
    // that shadowed these members and were destroyed immediately; that dead
    // code has been removed.
    HostCommandTest() :
        rateLimiter(kRateLimiterMilliSeconds),
        logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds),
        hostCmd(&msg, &io, &logCollectorUtil, kEndpointTestName, 0, true)
    {}

    // Message interface handle
    internal::MessageMock msg;

    // Host Command interface handle and its dependencies. Members are
    // initialised in declaration order, so msg, rateLimiter and
    // logCollectorUtil exist before hostCmd is given pointers to them.
    RateLimiter rateLimiter;
    LogCollectorUtil logCollectorUtil;
    internal::HostCommandImpl hostCmd;
};

TEST_F(HostCommandTest, syncBlocked)
{
    // A zero-filled PAYLOAD_UPDATE_INITIATE request gets a synthesised reply
    // whose result byte is 4 (presumably EC_RES_ACCESS_DENIED, as in
    // IsCommandLongRunningTest.Initiate). No mailbox traffic is scripted, and
    // this is not counted as a communication failure.
    std::vector<uint8_t> request(sizeof(ReqHeader) +
                                 sizeof(payload_update_packet));
    auto view = stdplus::raw::asSpan<uint8_t>(request);
    auto& header = stdplus::raw::extractRef<ReqHeader>(view);
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.command =
        EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PAYLOAD_UPDATE;
    auto& packet = stdplus::raw::extractRef<payload_update_packet>(view);
    packet.type = PAYLOAD_UPDATE_INITIATE;

    const std::vector<uint8_t> expected{3, 249, 4, 0, 0, 0, 0, 0};
    EXPECT_THAT(hostCmd.sendCommand(request), ContainerEq(expected));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncBasic)
{
    // Happy path: the request bytes reach msg.send unchanged and the canned
    // response is returned.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));

    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncEmptyReq)
{
    // An empty request is rejected up front with CommandFailure. That is a
    // caller error, not a communication failure.
    const std::vector<uint8_t> emptyRequest{};

    EXPECT_THROW(hostCmd.sendCommand(emptyRequest), CommandFailure);
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncShortRsp)
{
    // Drop the last response byte but keep the checksum valid by folding that
    // byte into byte 1. The frame is then shorter than it claims to be.
    std::vector<uint8_t> truncatedRsp = hello_rsp;
    truncatedRsp[1] += truncatedRsp.back();
    truncatedRsp.pop_back();

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(truncatedRsp));

    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncOffsetReq)
{
    // Corrupt the final request byte without fixing the checksum. The request
    // is rejected locally with CommandFailure.
    std::vector<uint8_t> corruptedReq = hello_req;
    corruptedReq.back() += 5;

    EXPECT_THROW(hostCmd.sendCommand(corruptedReq), CommandFailure);
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncOffsetRsp)
{
    // Corrupt the final response byte without fixing the checksum. The reply
    // is rejected with ResponseFailure and flagged as a communication failure.
    std::vector<uint8_t> corruptedRsp = hello_rsp;
    corruptedRsp.back() += 5;

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(corruptedRsp));

    EXPECT_THROW(hostCmd.sendCommand(hello_req), ResponseFailure);
    EXPECT_TRUE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncTrailingReq)
{
    // Five trailing 42 bytes after a valid request are tolerated. BufMatch is
    // checked against the padded buffer.
    std::vector<uint8_t> paddedReq = hello_req;
    paddedReq.insert(paddedReq.end(), 5, 42);

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(paddedReq));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));

    EXPECT_THAT(hostCmd.sendCommand(paddedReq), ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncTrailingRsp)
{
    // Five trailing 42 bytes after a valid response are not part of the
    // result: the caller gets exactly hello_rsp back.
    std::vector<uint8_t> paddedRsp = hello_rsp;
    paddedRsp.insert(paddedRsp.end(), 5, 42);

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(paddedRsp));

    EXPECT_THAT(hostCmd.sendCommand(hello_req), ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncSendFail)
{
    // Any exception from msg.send surfaces as InterfaceError and marks the
    // link as failed (no retries: the fixture uses max_retries = 0).
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception{}));

    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncRecvFail)
{
    // The send succeeds but msg.recv throws. This surfaces as InterfaceError
    // and marks the link as failed.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillOnce(Throw(std::exception{}));

    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}

TEST_F(HostCommandTest, syncWrongStructVersionReq)
{
    // Bump struct_version (the 1-byte first field) and compensate in the
    // checksum (byte 1), so the only defect is an unsupported version.
    std::vector<uint8_t> badVersionReq = hello_req;
    badVersionReq.at(0) += 5;
    badVersionReq.at(1) -= 5;

    EXPECT_THROW(hostCmd.sendCommand(badVersionReq), CommandFailure);
}

TEST_F(HostCommandTest, syncWrongStructVersionRsp)
{
    // Bump the response's struct_version (the 1-byte first field) and
    // compensate in the checksum (byte 1), so the only defect is an
    // unsupported version.
    std::vector<uint8_t> badVersionRsp = hello_rsp;
    badVersionRsp.at(0) += 5;
    badVersionRsp.at(1) -= 5;

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(badVersionRsp));

    EXPECT_THROW(hostCmd.sendCommand(hello_req), ResponseFailure);
}

// gmock action: fulfils the promise pointed to by `promise` (used with
// std::promise<void>* in this file) so the test thread can observe that the
// mocked call was reached.
ACTION_P(Resume, promise)
{
    auto& target = *promise;
    target.set_value();
}

// gmock action: blocks the thread running the mocked call until the future
// pointed to by `future` becomes ready.
ACTION_P(WaitOn, future)
{
    auto& pending = *future;
    pending.wait();
}

TEST_F(HostCommandTest, requiresLock)
{
    // Verifies that concurrent sendCommand calls are serialised. While the
    // first caller is parked inside msg.send (waiting on f2), the second
    // caller must not reach its own msg.send (which would fulfil p3).
    std::promise<void> p1, p2, p3;
    std::future<void> f1 = p1.get_future(), f2 = p2.get_future(),
                      f3 = p3.get_future();

    // Fulfils the wrapped promise exactly once: either explicitly through
    // release(), or on scope exit. Without this, an early ASSERT_* return
    // would leave the first sender blocked on f2 forever. The std::async
    // futures' destructors, which wait for their tasks, would then hang the
    // test binary instead of reporting the failure.
    struct ReleaseOnExit
    {
        std::promise<void>& target;
        bool released = false;

        void release()
        {
            if (!released)
            {
                released = true;
                target.set_value();
            }
        }

        ~ReleaseOnExit()
        {
            release();
        }
    };

    {
        testing::InSequence seq;
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p1), WaitOn(&f2)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p3)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }

    // The futures are declared before the releaser, so they are destroyed
    // after it. By the time they wait on their tasks, the first sender has
    // been unblocked.
    std::future<void> h1;
    std::future<void> h2;
    ReleaseOnExit releaseFirstSender{p2};

    h1 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req); });
    // Ensure the first task claims the lock
    ASSERT_EQ(std::future_status::ready, f1.wait_for(1s));
    f1.get();

    h2 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req); });
    // Ensure the second task gets to its locking region without running recv
    ASSERT_EQ(std::future_status::timeout, f3.wait_for(1s));

    // Neither thread should be done yet
    EXPECT_EQ(std::future_status::timeout, h1.wait_for(0s));
    EXPECT_EQ(std::future_status::timeout, h2.wait_for(0s));

    // Both should quickly complete now
    releaseFirstSender.release();
    ASSERT_EQ(std::future_status::ready, f3.wait_for(1s));
    f3.get();
    ASSERT_EQ(std::future_status::ready, h1.wait_for(1s));
    h1.get();
    ASSERT_EQ(std::future_status::ready, h2.wait_for(1s));
    h2.get();
}

TEST_F(HostCommandTest, timeoutLock)
{
    // Verifies that a sendCommand with a 100ms timeout gives up with Timeout
    // while another caller holds the lock, and that the lock holder still
    // completes afterwards.
    std::promise<void> p1, p2;
    std::future<void> f1 = p1.get_future(), f2 = p2.get_future();

    // Fulfils the wrapped promise exactly once: either explicitly through
    // release(), or on scope exit. Without this, an early ASSERT_* return
    // would leave the first sender blocked on f2 forever. The std::async
    // futures' destructors, which wait for their tasks, would then hang the
    // test binary instead of reporting the failure.
    struct ReleaseOnExit
    {
        std::promise<void>& target;
        bool released = false;

        void release()
        {
            if (!released)
            {
                released = true;
                target.set_value();
            }
        }

        ~ReleaseOnExit()
        {
            release();
        }
    };

    {
        testing::InSequence seq;
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p1), WaitOn(&f2)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }

    // The futures are declared before the releaser, so they are destroyed
    // after it. By the time they wait on their tasks, the first sender has
    // been unblocked.
    std::future<void> h1;
    std::future<void> h2;
    ReleaseOnExit releaseFirstSender{p2};

    h1 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req, 100ms); });
    // Ensure the first task claims the lock
    ASSERT_EQ(std::future_status::ready, f1.wait_for(1s));
    f1.get();

    h2 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req, 100ms); });
    // Ensure the second thread hits a timeout on the lock
    ASSERT_EQ(std::future_status::ready, h2.wait_for(1s));
    EXPECT_THROW(h2.get(), Timeout);

    // The first thread should still complete
    EXPECT_EQ(std::future_status::timeout, h1.wait_for(0s));
    releaseFirstSender.release();
    ASSERT_EQ(std::future_status::ready, h1.wait_for(1s));
    h1.get();
}

// Fixture for the sendCommand(command, version, data, size) overload.
class OverloadedHostCommandTest : public HostCommandTest
{
  protected:
    // hello_req_str split into its parts: command 0x0001, command version 0,
    // and the four request bytes 0x42 0x00 0x00 0x00. Feeding these to the
    // overload should produce exactly hello_req on the wire.
    const uint16_t hello_struct_command{1};
    const uint8_t hello_command_version{0};
    std::vector<uint8_t> hello_command_request{0x42, 0, 0, 0};
};

TEST_F(OverloadedHostCommandTest, sendCommandSuccessfully)
{
    // Same as HostCommandTest.syncBasic, but the request frame is assembled by
    // the overloaded sendCommand. msg.send must still see exactly hello_req.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));

    const auto reply = hostCmd.sendCommand(
        hello_struct_command, hello_command_version,
        hello_command_request.data(), hello_command_request.size());
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
}

TEST_F(OverloadedHostCommandTest, sendEmptyReqSuccessfully)
{
    // The hello command with no payload (nullptr, size 0) must go out as a
    // bare 8-byte header.
    static const std::vector<uint8_t> emptyHelloReq{0x03, 0xfc, 0x01, 0x00,
                                                    0x00, 0x00, 0x00, 0x00};

    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(emptyHelloReq));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));

    const auto reply = hostCmd.sendCommand(hello_struct_command,
                                           hello_command_version, nullptr, 0);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
}

TEST_F(OverloadedHostCommandTest, sendEmptyReqWithNonZeroSizeFails)
{
    // A null payload paired with a non-zero size is inconsistent and is
    // rejected with CommandFailure before anything is sent.
    constexpr size_t kBogusSize = 1;
    EXPECT_THROW(hostCmd.sendCommand(hello_struct_command,
                                     hello_command_version, nullptr,
                                     kBogusSize),
                 CommandFailure);
}

TEST_F(OverloadedHostCommandTest, sendOverflowRequestFails)
{
    // A payload one byte larger than UINT16_MAX is rejected with
    // CommandFailure before anything is sent.
    static const std::vector<uint8_t> oversizedRequest(UINT16_MAX + 1);

    EXPECT_THROW(hostCmd.sendCommand(hello_struct_command,
                                     hello_command_version,
                                     oversizedRequest.data(),
                                     oversizedRequest.size()),
                 CommandFailure);
}

TEST(HostCommandRetryTest, syncRetry)
{
    // With max_retries = 1, a single failed mailbox send is retried. The first
    // send throws, the retry and the recv succeed, and the caller sees a
    // normal response with no communication failure recorded.
    internal::MessageMock msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    internal::HostCommandImpl hostCmd(&msg, &io, &logCollectorUtil,
                                      kEndpointTestName, /*max_retries=*/1,
                                      true);

    {
        ::testing::InSequence seq;
        EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception{}));
        EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }

    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

TEST(HostCommandRetryTest, syncRetryFail)
{
    // Retries are only spent while the link is considered healthy:
    //  1. With max_retries = 2, a failing call makes 3 send attempts and then
    //     sets communicationFailure().
    //  2. While that flag is set, a failing call makes only 1 attempt.
    //  3. A subsequent successful call clears the flag.
    testing::StrictMock<internal::MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    internal::HostCommandImpl hostCmd(&msg, &io, &logCollectorUtil,
                                      kEndpointTestName, /*max_retries=*/2,
                                      true);
    EXPECT_FALSE(hostCmd.communicationFailure());

    // Step 1: healthy link, all 1 + 2 attempts fail.
    EXPECT_CALL(msg, send(_, _, _))
        .Times(3)
        .WillRepeatedly(Throw(std::exception{}));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), std::exception);
    EXPECT_TRUE(hostCmd.communicationFailure());

    // Step 2: link already marked failed, so only one attempt is made.
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception{}));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), std::exception);
    EXPECT_TRUE(hostCmd.communicationFailure());

    // Step 3: a successful round trip clears the failure flag.
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}

} // namespace
