blob: fe54040e57ace15193b1fc4a4faf6dc346729848 [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "google3/host_commands.h"
#include "host_command.hpp"
#include "message_mock.hpp"
#include "message_util.hpp"
#include "payload_update.hpp"
#include <boost/asio.hpp>
#include <stdplus/raw.hpp>
#include <xyz/openbmc_project/Common/error.hpp>
#include <xyz/openbmc_project/Control/Hoth/error.hpp>
#include <array>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <exception>
#include <future>
#include <limits>
#include <span>
#include <string>
#include <thread>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
using namespace std::literals;
using sdbusplus::error::xyz::openbmc_project::control::hoth::CommandFailure;
using sdbusplus::error::xyz::openbmc_project::control::hoth::InterfaceError;
using sdbusplus::error::xyz::openbmc_project::control::hoth::ResponseFailure;
using sdbusplus::error::xyz::openbmc_project::common::Timeout;
using google::hoth::MessageIntf;
using google::hoth::internal::EC_RES_ACCESS_DENIED;
using google::hoth::internal::EC_RES_SUCCESS;
using google::hoth::internal::EC_RES_INVALID_PARAM;
using google::hoth::internal::HostCommandImpl;
using google::hoth::internal::LogCollectorUtil;
using google::hoth::internal::MessageMock;
using google::hoth::internal::RateLimiter;
using google::hoth::internal::ReqHeader;
using google::hoth::internal::SUPPORTED_STRUCT_VERSION;
using ::testing::_;
using ::testing::ContainerEq;
using ::testing::DoAll;
using ::testing::Return;
using ::testing::Throw;
namespace
{
// Endpoint name used when constructing the HostCommandImpl instances under
// test.
const std::string kEndpointTestName{"test"};
// Second constructor argument handed to LogCollectorUtil in these tests
// (an async wait time in seconds, going by the name).
constexpr int kTestAsyncWaitTimeInSeconds{1};
/**
 * Test double for HostCommandImpl.
 *
 * Both sendCommand overloads are replaced with gmock mocks. The
 * callParentAsynHothCollectorFuture, callParentScheduler and
 * stopParentScheduler helpers forward straight to the base-class
 * collectHothLogsAsync, collectUartLogsAsync and stopUartLogs so tests can
 * drive log collection while controlling every command reply.
 */
class HostCommandImplMock : public HostCommandImpl
{
  public:
    HostCommandImplMock(MessageIntf* msg, boost::asio::io_context* io,
                        LogCollectorUtil* logCollectorUtil,
                        const std::string& name, size_t max_retries = 0,
                        bool allowLegacyVerify = true,
                        const uint32_t uartChannelId = 0) :
        HostCommandImpl(msg, io, logCollectorUtil, name, max_retries,
                        allowLegacyVerify, uartChannelId)
    {}

    // Mocked sendCommand(request).
    MOCK_METHOD1(sendCommand,
                 std::vector<uint8_t>(const std::vector<uint8_t>&));
    // Mocked sendCommand(request, timeout).
    MOCK_METHOD2(sendCommand, std::vector<uint8_t>(const std::vector<uint8_t>&,
                                                   std::chrono::milliseconds));

    /** Starts HostCommandImpl::collectHothLogsAsync and returns its id. */
    int callParentAsynHothCollectorFuture(bool cleanupPromiseAfterExecution)
    {
        const int requestId =
            HostCommandImpl::collectHothLogsAsync(cleanupPromiseAfterExecution);
        return requestId;
    }

    /** Starts HostCommandImpl::collectUartLogsAsync. */
    void callParentScheduler()
    {
        HostCommandImpl::collectUartLogsAsync();
    }

    /** Calls HostCommandImpl::stopUartLogs. */
    void stopParentScheduler()
    {
        HostCommandImpl::stopUartLogs();
    }
};
/**
 * Runs one asynchronous Hoth log collection in which every mocked
 * sendCommand(request) returns the same canned reply, then waits for the
 * collection's promise.
 */
TEST(HothLogCollector, TriggerLogCollector)
{
    testing::StrictMock<MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter limiter(kRateLimiterMilliSeconds);
    LogCollectorUtil collectorUtil(limiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &collectorUtil,
                                        kEndpointTestName, 0, true);

    // Collection is expected to issue exactly two commands, each answered
    // with this reply.
    const std::vector<uint8_t> cannedRsp{0x03, 0xFD, 0x00, 0x00,
                                         0x00, 0x00, 0x00, 0x00};
    EXPECT_CALL(hostCommandMock, sendCommand(_))
        .Times(2)
        .WillRepeatedly(Return(cannedRsp));

    // `false` presumably keeps the promise alive after execution so that
    // waitOnLogCollectorPromise can still find it (confirm in host_command).
    const int requestId =
        hostCommandMock.callParentAsynHothCollectorFuture(false);
    io.run();
    hostCommandMock.waitOnLogCollectorPromise(requestId);
    io.stop();
}
/**
 * Drives one UART log collection step against canned replies.
 *
 * The first mocked sendCommand(request, timeout) returns
 * writechannelOffsetResponse; every later call returns `response`.
 */
TEST(UartLogCollector, TriggerLogCollector)
{
    testing::StrictMock<MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    // Use the shared test constants rather than repeating their values as
    // literals, so every test in this file stays configured the same way.
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);
    // Reply to the initial channel-write-offset query.
    std::vector<uint8_t> writechannelOffsetResponse = {
        0x03, 0x37, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x44, 0x11, 0x1b};
    // Reply to every subsequent read. Most of its payload decodes as ASCII
    // UART log lines ("[M30-666BEC4A-E] M: ERR SMBCN:35 ...").
    std::vector<uint8_t> response = {
        0x03, 0x2a, 0x00, 0x00, 0xf8, 0x03, 0x00, 0x00, 0xda, 0x84, 0x30, 0x1b,
        0x78, 0x3a, 0x38, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36,
        0x42, 0x45, 0x43, 0x34, 0x41, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20,
        0x45, 0x52, 0x52, 0x20, 0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35,
        0x20, 0x0a, 0x5b, 0x4d, 0x33, 0x31, 0x2d, 0x36, 0x36, 0x36, 0x42, 0x46,
        0x45, 0x34, 0x37, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52,
        0x52, 0x20, 0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a,
        0x5b, 0x4d, 0x33, 0x31, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x30,
        0x46, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52, 0x52, 0x20,
        0x53, 0x4d, 0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a, 0x5b, 0x4d,
        0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x31, 0x37, 0x2d,
        0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52, 0x52, 0x20, 0x53, 0x4d,
        0x42, 0x43, 0x4e, 0x3a, 0x33, 0x35, 0x20, 0x0a, 0x5b, 0x4d, 0x33, 0x30,
        0x2d, 0x36, 0x36, 0x36, 0x43, 0x30, 0x36, 0x33, 0x30, 0x2d, 0x4e, 0x5d,
        0x20, 0x67, 0x73, 0x74, 0x5f, 0x69, 0x6e, 0x6a, 0x5f, 0x73, 0x71, 0x65,
        0x5f, 0x70, 0x72, 0x6f, 0x63, 0x5f, 0x63, 0x74, 0x78, 0x2e, 0x74, 0x61,
        0x69, 0x6c, 0x3a, 0x31, 0x34, 0x2c, 0x20, 0x68, 0x65, 0x61, 0x64, 0x3a,
        0x31, 0x34, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43,
        0x30, 0x36, 0x34, 0x32, 0x2d, 0x41, 0x5d, 0x20, 0x4d, 0x61, 0x73, 0x74,
        0x65, 0x72, 0x20, 0x41, 0x52, 0x42, 0x4c, 0x4f, 0x53, 0x54, 0x20, 0x3a,
        0x35, 0x0a, 0x5b, 0x4d, 0x33, 0x30, 0x2d, 0x36, 0x36, 0x36, 0x43, 0x30,
        0x36, 0x34, 0x32, 0x2d, 0x45, 0x5d, 0x20, 0x4d, 0x3a, 0x20, 0x45, 0x52,
        0x52, 0x3a, 0x20, 0x49, 0x4e, 0x43, 0x20, 0x62, 0x61, 0x63, 0x6b, 0x3a,
        0x34, 0x30, 0x2c, 0x20, 0x6e, 0x4e, 0x3a, 0x20, 0x32, 0x0a, 0};
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .WillOnce(Return(writechannelOffsetResponse))
        .WillRepeatedly(Return(response));
    hostCommandMock.callParentScheduler();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}
/**
 * Every attempt to read the channel write offset fails with InterfaceError.
 *
 * The collector is expected to make one initial attempt plus three retries
 * (four sendCommand calls in total, enforced by the StrictMock cardinality)
 * and to handle each exception internally without crashing.
 */
TEST(UartLogCollector, GetChannelWriteOffsetThrowsException)
{
    testing::StrictMock<MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);
    // Initial failure + retries 1..3.
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .Times(4)
        .WillRepeatedly(Throw(InterfaceError()));
    hostCommandMock.callParentScheduler();
    // NOTE(review): presumably each run_one() dispatches one scheduled retry
    // (3 retries after the initial attempt); confirm against
    // collectUartLogsAsync.
    io.run_one();
    io.run_one();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}
/**
 * The channel-write-offset query succeeds, but the first log-buffer read
 * throws InterfaceError. The exception must be handled without crashing.
 */
TEST(UartLogCollector, LogBufferRetrievalThrowsException)
{
    testing::StrictMock<MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    LogCollectorUtil logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImplMock hostCommandMock(&msg, &io, &logCollectorUtil,
                                        kEndpointTestName, 0, true);
    // Successful reply to the channel-write-offset query.
    std::vector<uint8_t> writechannelOffsetResponse = {
        0x03, 0x37, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x52, 0x44, 0x11, 0x1b};
    EXPECT_CALL(hostCommandMock, sendCommand(_, _))
        .WillOnce(Return(writechannelOffsetResponse))
        .WillOnce(Throw(InterfaceError()));
    hostCommandMock.callParentScheduler();
    io.run_one();
    hostCommandMock.stopParentScheduler();
    io.stop();
}
/**
 * Builds a payload-update host command request: a ReqHeader followed by a
 * payload_update_packet.
 *
 * Only struct_version and command (header) and type, offset and len
 * (packet) are set. Every other byte is zero.
 *
 * @param type   payload update sub-command, e.g. PAYLOAD_UPDATE_INITIATE
 * @param offset value stored in payload_update_packet::offset
 * @param len    value stored in payload_update_packet::len
 * @return the request bytes in a fixed-size std::array
 */
auto fillPayloadReq(uint8_t type, uint32_t offset = 0, uint32_t len = 0)
{
    // Value-initialize: a default-initialized std::array<uint8_t, N> holds
    // indeterminate bytes, and the header/packet fields not assigned below
    // would then be read uninitialized by isCommandLongRunning (UB).
    std::array<uint8_t, sizeof(ReqHeader) + sizeof(payload_update_packet)>
        ret{};
    auto req_view = stdplus::raw::asSpan<uint8_t>(ret);
    auto& hdr = stdplus::raw::extractRef<ReqHeader>(req_view);
    hdr.struct_version = SUPPORTED_STRUCT_VERSION;
    hdr.command = EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PAYLOAD_UPDATE;
    auto& req = stdplus::raw::extractRef<payload_update_packet>(req_view);
    req.type = type;
    req.offset = offset;
    req.len = len;
    return ret;
}
TEST(IsCommandLongRunningTest, ShortRequest)
{
    // A buffer one byte shorter than header + payload_update_packet is not
    // rejected, even though it would otherwise be an INITIATE request.
    auto fullReq = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    auto truncated = std::span<uint8_t>(fullReq).first(fullReq.size() - 1);
    const auto rc = HostCommandImpl::isCommandLongRunning(truncated, true);
    EXPECT_EQ(EC_RES_SUCCESS, rc);
}
TEST(IsCommandLongRunningTest, NonPayload)
{
    // Same bytes as an INITIATE request, but with the header command cleared
    // to 0 so it is no longer a payload-update command.
    auto reqBytes = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    stdplus::raw::refFrom<ReqHeader>(reqBytes).command = 0;
    const auto rc = HostCommandImpl::isCommandLongRunning(reqBytes, true);
    EXPECT_EQ(EC_RES_SUCCESS, rc);
}
TEST(IsCommandLongRunningTest, Initiate)
{
    // PAYLOAD_UPDATE_INITIATE is always refused as long-running.
    auto initiateReq = fillPayloadReq(PAYLOAD_UPDATE_INITIATE);
    const auto rc = HostCommandImpl::isCommandLongRunning(initiateReq, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, rc);
}
TEST(IsCommandLongRunningTest, Finalize)
{
    // PAYLOAD_UPDATE_FINALIZE is refused as long-running.
    auto finalizeReq = fillPayloadReq(PAYLOAD_UPDATE_FINALIZE);
    const auto rc = HostCommandImpl::isCommandLongRunning(finalizeReq, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, rc);
}
TEST(IsCommandLongRunningTest, VerifyAllowed)
{
    // With the second argument true, PAYLOAD_UPDATE_VERIFY is permitted.
    constexpr bool allowLegacyVerify = true;
    auto verifyReq = fillPayloadReq(PAYLOAD_UPDATE_VERIFY);
    EXPECT_EQ(EC_RES_SUCCESS, HostCommandImpl::isCommandLongRunning(
                                  verifyReq, allowLegacyVerify));
}
TEST(IsCommandLongRunningTest, VerifyBanned)
{
    // With the second argument false, PAYLOAD_UPDATE_VERIFY is refused.
    constexpr bool allowLegacyVerify = false;
    auto verifyReq = fillPayloadReq(PAYLOAD_UPDATE_VERIFY);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, HostCommandImpl::isCommandLongRunning(
                                        verifyReq, allowLegacyVerify));
}
TEST(IsCommandLongRunningTest, ShortErase)
{
    // A 10-unit erase at offset 1 << 24 is short enough to be allowed.
    constexpr uint32_t kOffset = uint32_t{1} << 24;
    constexpr uint32_t kLen = 10;
    auto eraseReq = fillPayloadReq(PAYLOAD_UPDATE_ERASE, kOffset, kLen);
    const auto rc = HostCommandImpl::isCommandLongRunning(eraseReq, true);
    EXPECT_EQ(EC_RES_SUCCESS, rc);
}
TEST(IsCommandLongRunningTest, LongErase)
{
    // An erase of the largest representable length is refused.
    constexpr uint32_t kMaxLen = std::numeric_limits<uint32_t>::max();
    auto eraseReq = fillPayloadReq(PAYLOAD_UPDATE_ERASE, 0, kMaxLen);
    const auto rc = HostCommandImpl::isCommandLongRunning(eraseReq, true);
    EXPECT_EQ(EC_RES_ACCESS_DENIED, rc);
}
TEST(IsCommandLongRunningTest, ShortVerifyChunk)
{
    // PAYLOAD_UPDATE_VERIFY_CHUNK yields EC_RES_INVALID_PARAM even for a
    // short range.
    constexpr uint32_t kOffset = uint32_t{1} << 24;
    constexpr uint32_t kLen = 10;
    auto chunkReq = fillPayloadReq(PAYLOAD_UPDATE_VERIFY_CHUNK, kOffset, kLen);
    const auto rc = HostCommandImpl::isCommandLongRunning(chunkReq, true);
    EXPECT_EQ(EC_RES_INVALID_PARAM, rc);
}
TEST(IsCommandLongRunningTest, LongVerifyChunk)
{
    // PAYLOAD_UPDATE_VERIFY_CHUNK yields EC_RES_INVALID_PARAM for the
    // maximum length too.
    constexpr uint32_t kMaxLen = std::numeric_limits<uint32_t>::max();
    auto chunkReq = fillPayloadReq(PAYLOAD_UPDATE_VERIFY_CHUNK, 0, kMaxLen);
    const auto rc = HostCommandImpl::isCommandLongRunning(chunkReq, true);
    EXPECT_EQ(EC_RES_INVALID_PARAM, rc);
}
// Canned "hello" request/response frames shared by the tests below.
// Byte 0 is struct_version and byte 1 the checksum; the bytes of each frame
// sum to 0 modulo 256.
const auto hello_req_str = "\x03\xb6\x01\x00\x00\x00\x04\x00\x42\x00\x00\x00"s;
const auto hello_rsp_str = "\x03\xad\x00\x00\x04\x00\x00\x00\x46\x03\x02\x01"s;
const auto hello_req =
    std::vector<uint8_t>(hello_req_str.cbegin(), hello_req_str.cend());
const auto hello_rsp =
    std::vector<uint8_t>(hello_rsp_str.cbegin(), hello_rsp_str.cend());
// Thrown by the BufCopy action when a mocked recv asks for bytes beyond the
// end of its canned buffer.
struct BufError : std::exception
{
    const char* what() const noexcept override
    {
        return "Not enough buffer";
    }
};
/**
 * gmock action for recv-style calls: copies arg1 bytes of `tgt`, starting
 * at offset arg2, into the destination buffer arg0.
 *
 * Throws BufError when the requested range extends past the end of `tgt`.
 * The bound is checked as two comparisons so that a huge offset cannot make
 * `arg2 + arg1` wrap around, slip past the check, and reach memcpy out of
 * bounds.
 */
ACTION_P(BufCopy, tgt)
{
    if (arg2 > tgt.size() || arg1 > tgt.size() - arg2)
    {
        throw BufError();
    }
    std::memcpy(arg0, tgt.data() + arg2, arg1);
}
/**
 * gmock action for send-style calls: checks that the arg1 bytes at arg0 equal
 * the bytes of `tgt` starting at offset arg2.
 *
 * The range is validated with two assertions so that a huge offset cannot
 * make `arg2 + arg1` wrap around and let memcmp read past the end of `tgt`.
 */
ACTION_P(BufMatch, tgt)
{
    ASSERT_LE(arg2, tgt.size());
    ASSERT_LE(arg1, tgt.size() - arg2);
    EXPECT_EQ(0, std::memcmp(arg0, tgt.data() + arg2, arg1));
}
boost::asio::io_context io; // Global test io context for async timer

/**
 * Fixture that wires a HostCommandImpl (endpoint kEndpointTestName, no
 * retries) to a MessageMock and the global `io`.
 *
 * Members are initialized in declaration order: logCollectorUtil is built
 * from rateLimiter, and hostCmd from msg and logCollectorUtil, so the order
 * below must be kept.
 */
class HostCommandTest : public ::testing::Test
{
  protected:
    HostCommandTest() :
        rateLimiter(kRateLimiterMilliSeconds),
        logCollectorUtil(rateLimiter, kTestAsyncWaitTimeInSeconds),
        hostCmd(&msg, &io, &logCollectorUtil, kEndpointTestName, 0, true)
    {}

    // Message interface handle
    MessageMock msg;
    // Rate limiter backing logCollectorUtil
    RateLimiter rateLimiter;
    // Log collector handed to hostCmd
    LogCollectorUtil logCollectorUtil;
    // Host Command interface handle
    HostCommandImpl hostCmd;
};
TEST_F(HostCommandTest, syncBlocked)
{
    // Hand-build a PAYLOAD_UPDATE_INITIATE request; all other bytes are zero.
    std::vector<uint8_t> request(sizeof(ReqHeader) +
                                 sizeof(payload_update_packet));
    auto view = stdplus::raw::asSpan<uint8_t>(request);
    auto& header = stdplus::raw::extractRef<ReqHeader>(view);
    header.struct_version = SUPPORTED_STRUCT_VERSION;
    header.command =
        EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_PAYLOAD_UPDATE;
    auto& packet = stdplus::raw::extractRef<payload_update_packet>(view);
    packet.type = PAYLOAD_UPDATE_INITIATE;

    // The blocked command yields this fixed reply (result byte 4, presumably
    // EC_RES_ACCESS_DENIED), and the link is not flagged as failed.
    const std::vector<uint8_t> expected{3, 249, 4, 0, 0, 0, 0, 0};
    EXPECT_THAT(hostCmd.sendCommand(request), ContainerEq(expected));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncBasic)
{
    // The mailbox must be written with hello_req and reads back hello_rsp,
    // which sendCommand returns unchanged.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncEmptyReq)
{
    // An empty request is a caller error (CommandFailure), not a link error.
    const std::vector<uint8_t> emptyRequest{};
    EXPECT_THROW(hostCmd.sendCommand(emptyRequest), CommandFailure);
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncShortRsp)
{
    // Drop the final byte of hello_rsp and fold its value into the checksum
    // byte (index 1) so the checksum still balances. The reply is then one
    // byte short and must surface as an InterfaceError.
    std::vector<uint8_t> truncatedRsp(hello_rsp);
    truncatedRsp[1] += truncatedRsp.back();
    truncatedRsp.pop_back();
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(truncatedRsp));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncOffsetReq)
{
    // Perturb the last request byte without compensating in the checksum;
    // the request is rejected as a CommandFailure.
    auto corruptedReq = hello_req;
    corruptedReq.back() += 5;
    EXPECT_THROW(hostCmd.sendCommand(corruptedReq), CommandFailure);
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncOffsetRsp)
{
    // Perturb the last reply byte without compensating in the checksum; the
    // reply is rejected as a ResponseFailure and the link flagged as failed.
    auto corruptedRsp = hello_rsp;
    corruptedRsp.back() += 5;
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(corruptedRsp));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), ResponseFailure);
    EXPECT_TRUE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncTrailingReq)
{
    // Five trailing 42 bytes after the request are tolerated: the command
    // still completes with hello_rsp.
    auto paddedReq = hello_req;
    paddedReq.insert(paddedReq.end(), 5, 42);
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(paddedReq));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(paddedReq);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncTrailingRsp)
{
    // Five trailing 42 bytes in the mailbox after the reply are not included
    // in what sendCommand returns.
    auto paddedRsp = hello_rsp;
    paddedRsp.insert(paddedRsp.end(), 5, 42);
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(paddedRsp));
    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncSendFail)
{
    // A std::exception out of send() is reported as InterfaceError and
    // flags the link as failed.
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception()));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncRecvFail)
{
    // send() succeeds, but a std::exception out of recv() is reported as
    // InterfaceError and flags the link as failed.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillOnce(Throw(std::exception()));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), InterfaceError);
    EXPECT_TRUE(hostCmd.communicationFailure());
}
TEST_F(HostCommandTest, syncWrongStructVersionReq)
{
    // Raise struct_version (byte 0) by 5 and lower the checksum (byte 1) by 5
    // so that only the version is wrong; the request must be refused.
    auto badVersionReq = hello_req;
    badVersionReq[0] += 5;
    badVersionReq[1] -= 5;
    EXPECT_THROW(hostCmd.sendCommand(badVersionReq), CommandFailure);
}
TEST_F(HostCommandTest, syncWrongStructVersionRsp)
{
    // Same trick on the reply: a valid checksum but an unsupported
    // struct_version must be rejected as a ResponseFailure.
    auto badVersionRsp = hello_rsp;
    badVersionRsp[0] += 5;
    badVersionRsp[1] -= 5;
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(badVersionRsp));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), ResponseFailure);
}
// gmock action: fulfils the std::promise<void> that `signal` points to.
ACTION_P(Resume, signal)
{
    signal->set_value();
}
// gmock action: blocks the calling thread until the std::future<void> that
// `gate` points to is ready.
ACTION_P(WaitOn, gate)
{
    gate->wait();
}
/**
 * Two concurrent sendCommand() calls must be serialized: while the first is
 * parked inside send(), the second must not reach send().
 *
 * p1 is set when the first send() starts. The first send() then blocks until
 * p2 is set. p3 is set when the second send() starts.
 */
TEST_F(HostCommandTest, requiresLock)
{
    std::promise<void> p1, p2, p3;
    std::future<void> f1 = p1.get_future(), f2 = p2.get_future(),
                      f3 = p3.get_future();
    {
        testing::InSequence seq;
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p1), WaitOn(&f2)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p3)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }

    // Sets p2 when the test body exits, including through a failed ASSERT_*.
    // Without this, a broken lock (the failure this test looks for) returns
    // early while h1's thread is parked in WaitOn(&f2). std::async futures
    // join in their destructor, so the test would hang instead of failing.
    // The guard object is declared after h1/h2, so it is destroyed first.
    struct ReleaseOnExit
    {
        std::promise<void>& gate;
        bool released = false;

        void release()
        {
            if (!released)
            {
                released = true;
                gate.set_value();
            }
        }
        ~ReleaseOnExit()
        {
            release();
        }
    };
    std::future<void> h1;
    std::future<void> h2;
    ReleaseOnExit releaseFirst{p2};

    h1 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req); });
    // Ensure the first task claims the lock
    ASSERT_EQ(std::future_status::ready, f1.wait_for(1s));
    f1.get();
    h2 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req); });
    // Ensure the second task gets to its locking region without running recv
    ASSERT_EQ(std::future_status::timeout, f3.wait_for(1s));
    // Neither thread should be done yet
    EXPECT_EQ(std::future_status::timeout, h1.wait_for(0s));
    EXPECT_EQ(std::future_status::timeout, h2.wait_for(0s));
    // Both should quickly complete now
    releaseFirst.release();
    ASSERT_EQ(std::future_status::ready, f3.wait_for(1s));
    f3.get();
    ASSERT_EQ(std::future_status::ready, h1.wait_for(1s));
    h1.get();
    ASSERT_EQ(std::future_status::ready, h2.wait_for(1s));
    h2.get();
}
/**
 * While the first sendCommand() holds the lock (parked inside send()), a
 * second sendCommand() with a 100ms timeout must throw Timeout. The first
 * call still completes once released.
 *
 * p1 is set when the first send() starts; that send() blocks until p2 is set.
 */
TEST_F(HostCommandTest, timeoutLock)
{
    std::promise<void> p1, p2;
    std::future<void> f1 = p1.get_future(), f2 = p2.get_future();
    {
        testing::InSequence seq;
        EXPECT_CALL(msg, send(_, _, _))
            .WillOnce(DoAll(BufMatch(hello_req), Resume(&p1), WaitOn(&f2)));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }

    // Sets p2 when the test body exits, including through a failed ASSERT_*.
    // Without this, a broken timeout (h2 never ready) returns early while
    // h1's thread is parked in WaitOn(&f2). std::async futures join in their
    // destructor, so the test would hang instead of failing. The guard object
    // is declared after h1/h2, so it is destroyed first.
    struct ReleaseOnExit
    {
        std::promise<void>& gate;
        bool released = false;

        void release()
        {
            if (!released)
            {
                released = true;
                gate.set_value();
            }
        }
        ~ReleaseOnExit()
        {
            release();
        }
    };
    std::future<void> h1;
    std::future<void> h2;
    ReleaseOnExit releaseFirst{p2};

    h1 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req, 100ms); });
    // Ensure the first task claims the lock
    ASSERT_EQ(std::future_status::ready, f1.wait_for(1s));
    f1.get();
    h2 = std::async(std::launch::async,
                    [&] { hostCmd.sendCommand(hello_req, 100ms); });
    // Ensure the second thread hits a timeout on the lock
    ASSERT_EQ(std::future_status::ready, h2.wait_for(1s));
    EXPECT_THROW(h2.get(), Timeout);
    // The first thread should still complete
    EXPECT_EQ(std::future_status::timeout, h1.wait_for(0s));
    releaseFirst.release();
    ASSERT_EQ(std::future_status::ready, h1.wait_for(1s));
    h1.get();
}
/**
 * Fixture for the (command, version, data, size) sendCommand overload.
 *
 * The members are hello_req split into that overload's arguments. The
 * sendCommandSuccessfully test checks that reassembling them reproduces
 * hello_req byte for byte.
 */
class OverloadedHostCommandTest : public HostCommandTest
{
  protected:
    const uint16_t hello_struct_command{1};
    const uint8_t hello_command_version{0};
    // Payload carried after the header in hello_req.
    std::vector<uint8_t> hello_command_request{0x42, 0, 0, 0};
};
TEST_F(OverloadedHostCommandTest, sendCommandSuccessfully)
{
    // The syncBasic round trip, but through the overload that builds the
    // frame itself. The mailbox must still see exactly hello_req.
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(
        hello_struct_command, hello_command_version,
        hello_command_request.data(), hello_command_request.size());
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
}
TEST_F(OverloadedHostCommandTest, sendEmptyReqSuccessfully)
{
    // A nullptr payload of size 0 produces this header-only frame.
    static const auto emptyFrameStr = "\x03\xfc\x01\x00\x00\x00\x00\x00"s;
    static const std::vector<uint8_t> emptyFrame(emptyFrameStr.begin(),
                                                 emptyFrameStr.end());
    EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(emptyFrame));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(hello_struct_command,
                                           hello_command_version, nullptr, 0);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
}
TEST_F(OverloadedHostCommandTest, sendEmptyReqWithNonZeroSizeFails)
{
    // A nullptr payload that claims one byte is rejected by the
    // populateReqHeader check.
    constexpr size_t kBogusSize = 1;
    EXPECT_THROW(hostCmd.sendCommand(hello_struct_command,
                                     hello_command_version, nullptr,
                                     kBogusSize),
                 CommandFailure);
}
TEST_F(OverloadedHostCommandTest, sendOverflowRequestFails)
{
    // A payload of UINT16_MAX + 1 bytes is too large and is rejected by the
    // populateReqHeader check.
    static const std::vector<uint8_t> oversized(UINT16_MAX + 1);
    EXPECT_THROW(hostCmd.sendCommand(hello_struct_command,
                                     hello_command_version, oversized.data(),
                                     oversized.size()),
                 CommandFailure);
}
TEST(HostCommandRetryTest, syncRetry)
{
    // With one retry allowed, a first send() to the Hoth mailbox that throws
    // is retried. The retry succeeds, and the command completes without
    // leaving the link flagged as failed.
    MessageMock msg;
    boost::asio::io_context io;
    RateLimiter limiter(kRateLimiterMilliSeconds);
    LogCollectorUtil collector(limiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImpl hostCmd(&msg, &io, &collector, kEndpointTestName,
                            /*max_retries=*/1, /*allowLegacyVerify=*/true);
    {
        ::testing::InSequence inOrder;
        EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception()));
        EXPECT_CALL(msg, send(_, _, _)).WillRepeatedly(BufMatch(hello_req));
        EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    }
    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
TEST(HostCommandRetryTest, syncRetryFail)
{
    // max_retries = 2. A failing command is attempted 3 times (1 + 2
    // retries) and then latches communicationFailure(). While that flag is
    // set, a failing command is attempted only once. A successful command
    // clears the flag again.
    testing::StrictMock<MessageMock> msg;
    boost::asio::io_context io;
    RateLimiter limiter(kRateLimiterMilliSeconds);
    LogCollectorUtil collector(limiter, kTestAsyncWaitTimeInSeconds);
    HostCommandImpl hostCmd(&msg, &io, &collector, kEndpointTestName,
                            /*max_retries=*/2, /*allowLegacyVerify=*/true);
    EXPECT_FALSE(hostCmd.communicationFailure());

    // Healthy link: initial attempt plus two retries, all failing.
    EXPECT_CALL(msg, send(_, _, _))
        .Times(3)
        .WillRepeatedly(Throw(std::exception()));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), std::exception);
    EXPECT_TRUE(hostCmd.communicationFailure());

    // Link already marked failed: a single attempt, no retries.
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(Throw(std::exception()));
    EXPECT_THROW(hostCmd.sendCommand(hello_req), std::exception);
    EXPECT_TRUE(hostCmd.communicationFailure());

    // A successful exchange clears the failure flag.
    EXPECT_CALL(msg, send(_, _, _)).WillOnce(BufMatch(hello_req));
    EXPECT_CALL(msg, recv(_, _, _)).WillRepeatedly(BufCopy(hello_rsp));
    const auto reply = hostCmd.sendCommand(hello_req);
    EXPECT_THAT(reply, ContainerEq(hello_rsp));
    EXPECT_FALSE(hostCmd.communicationFailure());
}
} // namespace