Add test infra to support state machine testing Following changes are included: * Change in the pldm interface to make it test friendly * Added test cases for the base discovery state machine Google-Bug-Id: 320779802 Change-Id: I51abc449295c348c06f761c6e00da6839ba40db3 Signed-off-by: Harsh Tyagi <harshtya@google.com>
diff --git a/interface/pldm_interface.cpp b/interface/pldm_interface.cpp index 7795a79..33e88ec 100644 --- a/interface/pldm_interface.cpp +++ b/interface/pldm_interface.cpp
@@ -10,9 +10,10 @@ constexpr uint8_t MCTP_MSG_TYPE_PLDM = 1; -pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd, - const uint8_t* pldmReqMsg, - size_t reqMsgLen) +pldm_requester_rc_t PldmInterface::pldmSendAtNetwork(mctp_eid_t eid, + int networkId, int mctpFd, + const uint8_t* pldmReqMsg, + size_t reqMsgLen) { struct sockaddr_mctp addr; addr.smctp_family = AF_MCTP; @@ -37,9 +38,10 @@ return PLDM_REQUESTER_SUCCESS; } -pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd, - uint8_t** pldmRespMsg, size_t* respMsgLen, - int networkId) +pldm_requester_rc_t PldmInterface::mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd, + uint8_t** pldmRespMsg, + size_t* respMsgLen, + int networkId) { ssize_t minLen = sizeof(struct pldm_msg_hdr); struct sockaddr_mctp addr; @@ -87,9 +89,10 @@ return PLDM_REQUESTER_SUCCESS; } -pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd, - uint8_t** pldmRespMsg, size_t* respMsgLen, - int networkId) +pldm_requester_rc_t PldmInterface::recvAtNetwork(mctp_eid_t eid, int mctpFd, + uint8_t** pldmRespMsg, + size_t* respMsgLen, + int networkId) { pldm_requester_rc_t rc = mctpRecvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId); @@ -127,9 +130,11 @@ return PLDM_REQUESTER_SUCCESS; } -pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd, - uint8_t instanceId, uint8_t** pldmRespMsg, - size_t* respMsgLen) +pldm_requester_rc_t PldmInterface::pldmRecvAtNetwork(mctp_eid_t eid, + int networkId, int mctpFd, + uint8_t instanceId, + uint8_t** pldmRespMsg, + size_t* respMsgLen) { pldm_requester_rc_t rc = recvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
diff --git a/interface/pldm_interface.hpp b/interface/pldm_interface.hpp index 2533ca6..a98e8fa 100644 --- a/interface/pldm_interface.hpp +++ b/interface/pldm_interface.hpp
@@ -1,6 +1,6 @@ /** - * This file is derived from the upstream implementation of send and recv - * With an extra network id parameter + * The pldmSend and pldmRecv are derived from the upstream implementation of + * send and recv With an extra network id parameter */ #pragma once @@ -11,35 +11,54 @@ #include "libpldm/pldm.h" #include "mctp.h" -/** - * @brief Sends the PLDM message to a specific network and eid - * - * @param[in] eid - Destination Eid - * @param[in] networkId - Network id for the device - * @param[in] mctpFd - Socket address - * @param[in] pldmReqMsg - Pointer to the request message - * @param[in] reqMsgLen - Length of the request message - * - * @return pldm_requester_rc_t - */ -pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd, - const uint8_t* pldmReqMsg, - size_t reqMsgLen); +class PldmInterface +{ + public: + PldmInterface() = default; + ~PldmInterface() = default; -/** - * @brief Receives the PLDM message from a specific network and eid - * - * @param[in] eid - Destination Eid - * @param[in] networkId - Network id for the device - * @param[in] mctpFd - Socket address - * @param[in] instanceId - Instance id for the req/resp - * @param[out] pldmRespMsg - Pointer to the response message - * @param[out] respMsgLen - Length of the response message - * - * @return pldm_requester_rc_t - */ -pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd, - uint8_t instanceId, uint8_t** pldmRespMsg, - size_t* respMsgLen); + /** + * @brief Sends the PLDM message to a specific network and eid + * + * @param[in] eid - Destination Eid + * @param[in] networkId - Network id for the device + * @param[in] mctpFd - Socket address + * @param[in] pldmReqMsg - Pointer to the request message + * @param[in] reqMsgLen - Length of the request message + * + * @return pldm_requester_rc_t + */ + virtual pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, + int mctpFd, + const uint8_t* pldmReqMsg, + size_t reqMsgLen); + + /** + * @brief Receives the PLDM message from a specific network and eid + * + * @param[in] eid - Destination Eid + * @param[in] networkId - Network id for the device + * @param[in] mctpFd - Socket address + * @param[in] instanceId - Instance id for the req/resp + * @param[out] pldmRespMsg - Pointer to the response message + * @param[out] respMsgLen - Length of the response message + * + * @return pldm_requester_rc_t + */ + virtual pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, + int mctpFd, + uint8_t instanceId, + uint8_t** pldmRespMsg, + size_t* respMsgLen); + + private: + pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd, + uint8_t** pldmRespMsg, + size_t* respMsgLen, int networkId); + + pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd, + uint8_t** pldmRespMsg, size_t* respMsgLen, + int networkId); +}; #endif // BMC_PLDM_HPP
diff --git a/meson.build b/meson.build index 1fe2184..6ab0761 100644 --- a/meson.build +++ b/meson.build
@@ -41,10 +41,6 @@ libbej_dep = dependency('libbej') libpldm_dep = dependency('libpldm') -if get_option('tests').allowed() - subdir('tests') -endif - boost = dependency( 'boost', version : '>=1.82.0', @@ -85,6 +81,27 @@ req_src, ] +rded_lib = static_library( + 'rded_lib', + [ + 'interface/pldm_interface.cpp', + 'interface/pldm_rde.cpp', + 'util/matcher/rde_match_handler.cpp', + 'util/state_machine/discovery/base/base_disc_state_machine.cpp' + ], + include_directories: headers, + implicit_include_directories: false, + dependencies: deps) + +rded_lib_dep = declare_dependency( + dependencies: deps, + include_directories: headers, + link_with: rded_lib) + +if get_option('tests').allowed() + subdir('tests') +endif + executable( 'rded', sources,
diff --git a/tests/discovery/base/base_disc_state_machine_test.cpp b/tests/discovery/base/base_disc_state_machine_test.cpp new file mode 100644 index 0000000..7b49bb5 --- /dev/null +++ b/tests/discovery/base/base_disc_state_machine_test.cpp
@@ -0,0 +1,93 @@ +#include "libpldm/base.h" + +#include "tests/pldm_interface_mock.hpp" +#include "util/state_machine/discovery/base/base_disc_state_machine.hpp" + +#include <stdplus/print.hpp> + +#include <memory> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +using ::testing::_; +using ::testing::Return; + +constexpr int testFd = 3; +constexpr int testNetId = 6; +constexpr std::string testDeviceId = "1_1_3_1"; + +class BaseDiscoveryTest : public ::testing::Test +{ + protected: + std::shared_ptr<MockPldmInterface> mockInterface = + std::make_shared<MockPldmInterface>(); +}; + +TEST_F(BaseDiscoveryTest, StateMachineRunSuccess) +{ + std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase = + std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId, + testNetId, mockInterface); + + EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _)) + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTid + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTypes + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getVersions + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getCommands + + EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _)) + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTid + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTypes + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getVersions + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // receive for getCommands + + OperationStatus status = stateMachineBase.get()->run(); + stdplus::println(stderr, "Status: {}", static_cast<int>(status)); + // All invocations tested with success code (State machine runs fine) + EXPECT_EQ(status, OperationStatus::Success); +} + +TEST_F(BaseDiscoveryTest, StateMachineRunReceiveFails) +{ + // Fails at the second receive and hence discovery should fail and no more + // calls to be made to the interface + std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase = + std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId, + testNetId, mockInterface); + + EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _)) + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTid + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getTypes + + EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _)) + .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTid + .WillOnce(Return( + PLDM_REQUESTER_RESP_MSG_TOO_SMALL)); // receive fails for getTypes + + OperationStatus status = stateMachineBase.get()->run(); + stdplus::println(stderr, "Status: {}", static_cast<int>(status)); + + // Receive failure - hence discovery fails + EXPECT_EQ(status, OperationStatus::PldmRecvFailure); +} + +TEST_F(BaseDiscoveryTest, StateMachineRunSendFails) +{ + // Fails at the first send and hence discovery should fail and no more + // calls to be made to the interface + std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase = + std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId, + testNetId, mockInterface); + + EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _)) + .WillOnce(Return(PLDM_REQUESTER_SEND_FAIL)); // send for getTid fails + + // No receiveFrom calls are expected as it failed on first send + + OperationStatus status = stateMachineBase.get()->run(); + stdplus::println(stderr, "Status: {}", static_cast<int>(status)); + + // Send failure - hence discovery fails + EXPECT_EQ(status, OperationStatus::PldmSendFailure); +}
diff --git a/tests/meson.build b/tests/meson.build index 77b7918..cd2ed39 100644 --- a/tests/meson.build +++ b/tests/meson.build
@@ -19,22 +19,13 @@ endif endif -test( - 'test_utils', - executable( - 'test_utils', - 'rde_test.cpp', - dependencies: [gtest], - implicit_include_directories: false, - ) -) +tests = [ + 'mctp_setup_test', + 'discovery/base/base_disc_state_machine_test' +] -test( - 'test_mctp_setup', - executable( - 'test_mctp_setup', - 'mctp_setup_test.cpp', - dependencies: [gtest, req_src, stdplus], - implicit_include_directories: true, - ) -) +foreach t : tests + test(t, executable(t.underscorify(), t + '.cpp', + implicit_include_directories: false, + dependencies: [rded_lib_dep, req_src, stdplus, gtest, gmock])) +endforeach
diff --git a/tests/pldm_interface_mock.hpp b/tests/pldm_interface_mock.hpp new file mode 100644 index 0000000..54c8a7b --- /dev/null +++ b/tests/pldm_interface_mock.hpp
@@ -0,0 +1,21 @@ +#ifndef MOCK_PLDM_INTERFACE_HPP +#define MOCK_PLDM_INTERFACE_HPP + +#include "libpldm/base.h" +#include "libpldm/pldm.h" + +#include "interface/pldm_interface.hpp" + +#include "gmock/gmock.h" + +class MockPldmInterface : public PldmInterface +{ + public: + MOCK_METHOD(pldm_requester_rc_t, pldmSendAtNetwork, + (mctp_eid_t, int, int, const uint8_t*, size_t)); + + MOCK_METHOD(pldm_requester_rc_t, pldmRecvAtNetwork, + (mctp_eid_t, int, int, uint8_t, uint8_t**, size_t*)); +}; + +#endif // MOCK_PLDM_INTERFACE_HPP
diff --git a/tests/rde_test.cpp b/tests/rde_test.cpp deleted file mode 100644 index 1ad6631..0000000 --- a/tests/rde_test.cpp +++ /dev/null
@@ -1,7 +0,0 @@ -#include <gtest/gtest.h> - -TEST(DummyTest, DummyTest) -{ - // TODO(@harshtya): Add tests for RDE - EXPECT_TRUE(true); -}
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.cpp b/util/state_machine/discovery/base/base_disc_state_machine.cpp index 76f5725..23973d2 100644 --- a/util/state_machine/discovery/base/base_disc_state_machine.cpp +++ b/util/state_machine/discovery/base/base_disc_state_machine.cpp
@@ -3,15 +3,14 @@ #include "libpldm/base.h" #include "libpldm/pldm.h" -#include "interface/pldm_interface.hpp" - #include <stdplus/print.hpp> #include <util/common.hpp> BaseDiscoveryStateMachine::BaseDiscoveryStateMachine( - int fd, const std::string& deviceName, int netId) : + int fd, const std::string& deviceName, int netId, + std::shared_ptr<PldmInterface> pldmInterface) : fd(fd), - deviceName(deviceName), netId(netId) + deviceName(deviceName), netId(netId), pldmInterface(pldmInterface) { stdplus::print(stderr, "Initializing base discovery state machine...\n"); this->initialized = true; @@ -96,8 +95,8 @@ return OperationStatus::EncodingRequestFailure; } - if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(), - requestMsg.size())) + if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd, + requestMsg.data(), requestMsg.size())) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmSendFailure; @@ -108,8 +107,9 @@ size_t responseMsgSize = sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES; auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg); - if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId, - &responseMsg, &responseMsgSize)) + if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd, + this->instanceId, &responseMsg, + &responseMsgSize)) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmRecvFailure; @@ -145,8 +145,8 @@ return OperationStatus::EncodingRequestFailure; } - if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(), - requestMsg.size())) + if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd, + requestMsg.data(), requestMsg.size())) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmSendFailure; @@ -157,8 +157,9 @@ size_t responseMsgSize = response.size(); auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg); - if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId, - &responseMsg, &responseMsgSize)) + if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd, + this->instanceId, &responseMsg, + &responseMsgSize)) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmRecvFailure; @@ -198,8 +199,8 @@ return OperationStatus::EncodingRequestFailure; } - if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(), - requestMsg.size())) + if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd, + requestMsg.data(), requestMsg.size())) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmSendFailure; @@ -211,8 +212,9 @@ size_t responseMsgSize = response.size(); auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg); - if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId, - &responseMsg, &responseMsgSize)) + if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd, + this->instanceId, &responseMsg, + &responseMsgSize)) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmRecvFailure; @@ -254,8 +256,8 @@ return OperationStatus::EncodingRequestFailure; } - if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(), - requestMsg.size())) + if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd, + requestMsg.data(), requestMsg.size())) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmSendFailure; @@ -266,8 +268,9 @@ size_t responseMsgSize = response.size(); auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg); - if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId, - &responseMsg, &responseMsgSize)) + if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd, + this->instanceId, &responseMsg, + &responseMsgSize)) { this->requesterStatus = StateMachineStatus::RequestFailed; return OperationStatus::PldmRecvFailure;
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.hpp b/util/state_machine/discovery/base/base_disc_state_machine.hpp index 55091bb..d773fcc 100644 --- a/util/state_machine/discovery/base/base_disc_state_machine.hpp +++ b/util/state_machine/discovery/base/base_disc_state_machine.hpp
@@ -4,9 +4,11 @@ #include "libpldm/base.h" #include "libpldm/pldm.h" +#include "interface/pldm_interface.hpp" #include "state_machine_factory.hpp" #include <array> +#include <memory> #include <optional> #include <unordered_map> @@ -26,7 +28,8 @@ class BaseDiscoveryStateMachine : public StateMachineFactory { public: - BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId); + BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId, + std::shared_ptr<PldmInterface> pldmInterface); OperationStatus run() override; @@ -82,6 +85,8 @@ std::array<std::array<uint8_t, PLDM_MAX_CMDS_PER_TYPE>, PLDM_MAX_TYPES> pldmCommands; std::array<ver32_t, PLDM_MAX_TYPES> pldmVersions; + + std::shared_ptr<PldmInterface> pldmInterface; }; #endif // BASE_DISC_STATE_MACHINE_HPP