Add test infra to support state machine testing
Following changes are included:
* Change in the pldm interface to make it test friendly
* Added test cases for the base discovery state machine
Google-Bug-Id: 320779802
Change-Id: I51abc449295c348c06f761c6e00da6839ba40db3
Signed-off-by: Harsh Tyagi <harshtya@google.com>
diff --git a/interface/pldm_interface.cpp b/interface/pldm_interface.cpp
index 7795a79..33e88ec 100644
--- a/interface/pldm_interface.cpp
+++ b/interface/pldm_interface.cpp
@@ -10,9 +10,10 @@
constexpr uint8_t MCTP_MSG_TYPE_PLDM = 1;
-pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
- const uint8_t* pldmReqMsg,
- size_t reqMsgLen)
+pldm_requester_rc_t PldmInterface::pldmSendAtNetwork(mctp_eid_t eid,
+ int networkId, int mctpFd,
+ const uint8_t* pldmReqMsg,
+ size_t reqMsgLen)
{
struct sockaddr_mctp addr;
addr.smctp_family = AF_MCTP;
@@ -37,9 +38,10 @@
return PLDM_REQUESTER_SUCCESS;
}
-pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
- uint8_t** pldmRespMsg, size_t* respMsgLen,
- int networkId)
+pldm_requester_rc_t PldmInterface::mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
+ uint8_t** pldmRespMsg,
+ size_t* respMsgLen,
+ int networkId)
{
ssize_t minLen = sizeof(struct pldm_msg_hdr);
struct sockaddr_mctp addr;
@@ -87,9 +89,10 @@
return PLDM_REQUESTER_SUCCESS;
}
-pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd,
- uint8_t** pldmRespMsg, size_t* respMsgLen,
- int networkId)
+pldm_requester_rc_t PldmInterface::recvAtNetwork(mctp_eid_t eid, int mctpFd,
+ uint8_t** pldmRespMsg,
+ size_t* respMsgLen,
+ int networkId)
{
pldm_requester_rc_t rc =
mctpRecvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
@@ -127,9 +130,11 @@
return PLDM_REQUESTER_SUCCESS;
}
-pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
- uint8_t instanceId, uint8_t** pldmRespMsg,
- size_t* respMsgLen)
+pldm_requester_rc_t PldmInterface::pldmRecvAtNetwork(mctp_eid_t eid,
+ int networkId, int mctpFd,
+ uint8_t instanceId,
+ uint8_t** pldmRespMsg,
+ size_t* respMsgLen)
{
pldm_requester_rc_t rc =
recvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
diff --git a/interface/pldm_interface.hpp b/interface/pldm_interface.hpp
index 2533ca6..a98e8fa 100644
--- a/interface/pldm_interface.hpp
+++ b/interface/pldm_interface.hpp
@@ -1,6 +1,6 @@
/**
- * This file is derived from the upstream implementation of send and recv
- * With an extra network id parameter
+ * The pldmSend and pldmRecv are derived from the upstream implementation of
+ * send and recv With an extra network id parameter
*/
#pragma once
@@ -11,35 +11,54 @@
#include "libpldm/pldm.h"
#include "mctp.h"
-/**
- * @brief Sends the PLDM message to a specific network and eid
- *
- * @param[in] eid - Destination Eid
- * @param[in] networkId - Network id for the device
- * @param[in] mctpFd - Socket address
- * @param[in] pldmReqMsg - Pointer to the request message
- * @param[in] reqMsgLen - Length of the request message
- *
- * @return pldm_requester_rc_t
- */
-pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
- const uint8_t* pldmReqMsg,
- size_t reqMsgLen);
+class PldmInterface
+{
+ public:
+ PldmInterface() = default;
+ ~PldmInterface() = default;
-/**
- * @brief Receives the PLDM message from a specific network and eid
- *
- * @param[in] eid - Destination Eid
- * @param[in] networkId - Network id for the device
- * @param[in] mctpFd - Socket address
- * @param[in] instanceId - Instance id for the req/resp
- * @param[out] pldmRespMsg - Pointer to the response message
- * @param[out] respMsgLen - Length of the response message
- *
- * @return pldm_requester_rc_t
- */
-pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
- uint8_t instanceId, uint8_t** pldmRespMsg,
- size_t* respMsgLen);
+ /**
+ * @brief Sends the PLDM message to a specific network and eid
+ *
+ * @param[in] eid - Destination Eid
+ * @param[in] networkId - Network id for the device
+ * @param[in] mctpFd - Socket address
+ * @param[in] pldmReqMsg - Pointer to the request message
+ * @param[in] reqMsgLen - Length of the request message
+ *
+ * @return pldm_requester_rc_t
+ */
+ virtual pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId,
+ int mctpFd,
+ const uint8_t* pldmReqMsg,
+ size_t reqMsgLen);
+
+ /**
+ * @brief Receives the PLDM message from a specific network and eid
+ *
+ * @param[in] eid - Destination Eid
+ * @param[in] networkId - Network id for the device
+ * @param[in] mctpFd - Socket address
+ * @param[in] instanceId - Instance id for the req/resp
+ * @param[out] pldmRespMsg - Pointer to the response message
+ * @param[out] respMsgLen - Length of the response message
+ *
+ * @return pldm_requester_rc_t
+ */
+ virtual pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId,
+ int mctpFd,
+ uint8_t instanceId,
+ uint8_t** pldmRespMsg,
+ size_t* respMsgLen);
+
+ private:
+ pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
+ uint8_t** pldmRespMsg,
+ size_t* respMsgLen, int networkId);
+
+ pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd,
+ uint8_t** pldmRespMsg, size_t* respMsgLen,
+ int networkId);
+};
#endif // BMC_PLDM_HPP
diff --git a/meson.build b/meson.build
index 1fe2184..6ab0761 100644
--- a/meson.build
+++ b/meson.build
@@ -41,10 +41,6 @@
libbej_dep = dependency('libbej')
libpldm_dep = dependency('libpldm')
-if get_option('tests').allowed()
- subdir('tests')
-endif
-
boost = dependency(
'boost',
version : '>=1.82.0',
@@ -85,6 +81,27 @@
req_src,
]
+rded_lib = static_library(
+ 'rded_lib',
+ [
+ 'interface/pldm_interface.cpp',
+ 'interface/pldm_rde.cpp',
+ 'util/matcher/rde_match_handler.cpp',
+ 'util/state_machine/discovery/base/base_disc_state_machine.cpp'
+ ],
+ include_directories: headers,
+ implicit_include_directories: false,
+ dependencies: deps)
+
+rded_lib_dep = declare_dependency(
+ dependencies: deps,
+ include_directories: headers,
+ link_with: rded_lib)
+
+if get_option('tests').allowed()
+ subdir('tests')
+endif
+
executable(
'rded',
sources,
diff --git a/tests/discovery/base/base_disc_state_machine_test.cpp b/tests/discovery/base/base_disc_state_machine_test.cpp
new file mode 100644
index 0000000..7b49bb5
--- /dev/null
+++ b/tests/discovery/base/base_disc_state_machine_test.cpp
@@ -0,0 +1,93 @@
+#include "libpldm/base.h"
+
+#include "tests/pldm_interface_mock.hpp"
+#include "util/state_machine/discovery/base/base_disc_state_machine.hpp"
+
+#include <stdplus/print.hpp>
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using ::testing::_;
+using ::testing::Return;
+
+constexpr int testFd = 3;
+constexpr int testNetId = 6;
+constexpr std::string testDeviceId = "1_1_3_1";
+
+class BaseDiscoveryTest : public ::testing::Test
+{
+ protected:
+ std::shared_ptr<MockPldmInterface> mockInterface =
+ std::make_shared<MockPldmInterface>();
+};
+
+TEST_F(BaseDiscoveryTest, StateMachineRunSuccess)
+{
+ std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+ std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+ testNetId, mockInterface);
+
+ EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTid
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTypes
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getVersions
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getCommands
+
+ EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _))
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTid
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTypes
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getVersions
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // receive for getCommands
+
+ OperationStatus status = stateMachineBase.get()->run();
+ stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+ // All invocations tested with success code (State machine runs fine)
+ EXPECT_EQ(status, OperationStatus::Success);
+}
+
+TEST_F(BaseDiscoveryTest, StateMachineRunReceiveFails)
+{
+ // Fails at the second receive and hence discovery should fail and no more
+ // calls to be made to the interface
+ std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+ std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+ testNetId, mockInterface);
+
+ EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // send for getTid
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getTypes
+
+ EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _))
+ .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTid
+ .WillOnce(Return(
+ PLDM_REQUESTER_RESP_MSG_TOO_SMALL)); // receive fails for getTypes
+
+ OperationStatus status = stateMachineBase.get()->run();
+ stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+
+ // Receive failure - hence discovery fails
+ EXPECT_EQ(status, OperationStatus::PldmRecvFailure);
+}
+
+TEST_F(BaseDiscoveryTest, StateMachineRunSendFails)
+{
+ // Fails at the first send and hence discovery should fail and no more
+ // calls to be made to the interface
+ std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+ std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+ testNetId, mockInterface);
+
+ EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+ .WillOnce(Return(PLDM_REQUESTER_SEND_FAIL)); // send for getTid fails
+
+ // No receiveFrom calls are expected as it failed on first send
+
+ OperationStatus status = stateMachineBase.get()->run();
+ stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+
+ // Send failure - hence discovery fails
+ EXPECT_EQ(status, OperationStatus::PldmSendFailure);
+}
diff --git a/tests/meson.build b/tests/meson.build
index 77b7918..cd2ed39 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -19,22 +19,13 @@
endif
endif
-test(
- 'test_utils',
- executable(
- 'test_utils',
- 'rde_test.cpp',
- dependencies: [gtest],
- implicit_include_directories: false,
- )
-)
+tests = [
+ 'mctp_setup_test',
+ 'discovery/base/base_disc_state_machine_test'
+]
-test(
- 'test_mctp_setup',
- executable(
- 'test_mctp_setup',
- 'mctp_setup_test.cpp',
- dependencies: [gtest, req_src, stdplus],
- implicit_include_directories: true,
- )
-)
+foreach t : tests
+ test(t, executable(t.underscorify(), t + '.cpp',
+ implicit_include_directories: false,
+ dependencies: [rded_lib_dep, req_src, stdplus, gtest, gmock]))
+endforeach
diff --git a/tests/pldm_interface_mock.hpp b/tests/pldm_interface_mock.hpp
new file mode 100644
index 0000000..54c8a7b
--- /dev/null
+++ b/tests/pldm_interface_mock.hpp
@@ -0,0 +1,21 @@
+#ifndef MOCK_PLDM_INTERFACE_HPP
+#define MOCK_PLDM_INTERFACE_HPP
+
+#include "libpldm/base.h"
+#include "libpldm/pldm.h"
+
+#include "interface/pldm_interface.hpp"
+
+#include "gmock/gmock.h"
+
+class MockPldmInterface : public PldmInterface
+{
+ public:
+ MOCK_METHOD(pldm_requester_rc_t, pldmSendAtNetwork,
+ (mctp_eid_t, int, int, const uint8_t*, size_t));
+
+ MOCK_METHOD(pldm_requester_rc_t, pldmRecvAtNetwork,
+ (mctp_eid_t, int, int, uint8_t, uint8_t**, size_t*));
+};
+
+#endif // MOCK_PLDM_INTERFACE_HPP
diff --git a/tests/rde_test.cpp b/tests/rde_test.cpp
deleted file mode 100644
index 1ad6631..0000000
--- a/tests/rde_test.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include <gtest/gtest.h>
-
-TEST(DummyTest, DummyTest)
-{
- // TODO(@harshtya): Add tests for RDE
- EXPECT_TRUE(true);
-}
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.cpp b/util/state_machine/discovery/base/base_disc_state_machine.cpp
index 76f5725..23973d2 100644
--- a/util/state_machine/discovery/base/base_disc_state_machine.cpp
+++ b/util/state_machine/discovery/base/base_disc_state_machine.cpp
@@ -3,15 +3,14 @@
#include "libpldm/base.h"
#include "libpldm/pldm.h"
-#include "interface/pldm_interface.hpp"
-
#include <stdplus/print.hpp>
#include <util/common.hpp>
BaseDiscoveryStateMachine::BaseDiscoveryStateMachine(
- int fd, const std::string& deviceName, int netId) :
+ int fd, const std::string& deviceName, int netId,
+ std::shared_ptr<PldmInterface> pldmInterface) :
fd(fd),
- deviceName(deviceName), netId(netId)
+ deviceName(deviceName), netId(netId), pldmInterface(pldmInterface)
{
stdplus::print(stderr, "Initializing base discovery state machine...\n");
this->initialized = true;
@@ -96,8 +95,8 @@
return OperationStatus::EncodingRequestFailure;
}
- if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
- requestMsg.size()))
+ if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+ requestMsg.data(), requestMsg.size()))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmSendFailure;
@@ -108,8 +107,9 @@
size_t responseMsgSize = sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES;
auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
- if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
- &responseMsg, &responseMsgSize))
+ if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+ this->instanceId, &responseMsg,
+ &responseMsgSize))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmRecvFailure;
@@ -145,8 +145,8 @@
return OperationStatus::EncodingRequestFailure;
}
- if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
- requestMsg.size()))
+ if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+ requestMsg.data(), requestMsg.size()))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmSendFailure;
@@ -157,8 +157,9 @@
size_t responseMsgSize = response.size();
auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
- if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
- &responseMsg, &responseMsgSize))
+ if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+ this->instanceId, &responseMsg,
+ &responseMsgSize))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmRecvFailure;
@@ -198,8 +199,8 @@
return OperationStatus::EncodingRequestFailure;
}
- if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
- requestMsg.size()))
+ if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+ requestMsg.data(), requestMsg.size()))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmSendFailure;
@@ -211,8 +212,9 @@
size_t responseMsgSize = response.size();
auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
- if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
- &responseMsg, &responseMsgSize))
+ if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+ this->instanceId, &responseMsg,
+ &responseMsgSize))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmRecvFailure;
@@ -254,8 +256,8 @@
return OperationStatus::EncodingRequestFailure;
}
- if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
- requestMsg.size()))
+ if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+ requestMsg.data(), requestMsg.size()))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmSendFailure;
@@ -266,8 +268,9 @@
size_t responseMsgSize = response.size();
auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
- if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
- &responseMsg, &responseMsgSize))
+ if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+ this->instanceId, &responseMsg,
+ &responseMsgSize))
{
this->requesterStatus = StateMachineStatus::RequestFailed;
return OperationStatus::PldmRecvFailure;
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.hpp b/util/state_machine/discovery/base/base_disc_state_machine.hpp
index 55091bb..d773fcc 100644
--- a/util/state_machine/discovery/base/base_disc_state_machine.hpp
+++ b/util/state_machine/discovery/base/base_disc_state_machine.hpp
@@ -4,9 +4,11 @@
#include "libpldm/base.h"
#include "libpldm/pldm.h"
+#include "interface/pldm_interface.hpp"
#include "state_machine_factory.hpp"
#include <array>
+#include <memory>
#include <optional>
#include <unordered_map>
@@ -26,7 +28,8 @@
class BaseDiscoveryStateMachine : public StateMachineFactory
{
public:
- BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId);
+ BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId,
+ std::shared_ptr<PldmInterface> pldmInterface);
OperationStatus run() override;
@@ -82,6 +85,8 @@
std::array<std::array<uint8_t, PLDM_MAX_CMDS_PER_TYPE>, PLDM_MAX_TYPES>
pldmCommands;
std::array<ver32_t, PLDM_MAX_TYPES> pldmVersions;
+
+ std::shared_ptr<PldmInterface> pldmInterface;
};
#endif // BASE_DISC_STATE_MACHINE_HPP