Add test infra to support state machine testing

Following changes are included:

* Change in the pldm interface to make it test friendly
* Added test cases for the base discovery state machine

Google-Bug-Id: 320779802
Change-Id: I51abc449295c348c06f761c6e00da6839ba40db3
Signed-off-by: Harsh Tyagi <harshtya@google.com>
diff --git a/interface/pldm_interface.cpp b/interface/pldm_interface.cpp
index 7795a79..33e88ec 100644
--- a/interface/pldm_interface.cpp
+++ b/interface/pldm_interface.cpp
@@ -10,9 +10,10 @@
 
 constexpr uint8_t MCTP_MSG_TYPE_PLDM = 1;
 
-pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
-                                      const uint8_t* pldmReqMsg,
-                                      size_t reqMsgLen)
+pldm_requester_rc_t PldmInterface::pldmSendAtNetwork(mctp_eid_t eid,
+                                                     int networkId, int mctpFd,
+                                                     const uint8_t* pldmReqMsg,
+                                                     size_t reqMsgLen)
 {
     struct sockaddr_mctp addr;
     addr.smctp_family = AF_MCTP;
@@ -37,9 +38,10 @@
     return PLDM_REQUESTER_SUCCESS;
 }
 
-pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
-                                      uint8_t** pldmRespMsg, size_t* respMsgLen,
-                                      int networkId)
+pldm_requester_rc_t PldmInterface::mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
+                                                     uint8_t** pldmRespMsg,
+                                                     size_t* respMsgLen,
+                                                     int networkId)
 {
     ssize_t minLen = sizeof(struct pldm_msg_hdr);
     struct sockaddr_mctp addr;
@@ -87,9 +89,10 @@
     return PLDM_REQUESTER_SUCCESS;
 }
 
-pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd,
-                                  uint8_t** pldmRespMsg, size_t* respMsgLen,
-                                  int networkId)
+pldm_requester_rc_t PldmInterface::recvAtNetwork(mctp_eid_t eid, int mctpFd,
+                                                 uint8_t** pldmRespMsg,
+                                                 size_t* respMsgLen,
+                                                 int networkId)
 {
     pldm_requester_rc_t rc =
         mctpRecvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
@@ -127,9 +130,11 @@
     return PLDM_REQUESTER_SUCCESS;
 }
 
-pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
-                                      uint8_t instanceId, uint8_t** pldmRespMsg,
-                                      size_t* respMsgLen)
+pldm_requester_rc_t PldmInterface::pldmRecvAtNetwork(mctp_eid_t eid,
+                                                     int networkId, int mctpFd,
+                                                     uint8_t instanceId,
+                                                     uint8_t** pldmRespMsg,
+                                                     size_t* respMsgLen)
 {
     pldm_requester_rc_t rc =
         recvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
diff --git a/interface/pldm_interface.hpp b/interface/pldm_interface.hpp
index 2533ca6..a98e8fa 100644
--- a/interface/pldm_interface.hpp
+++ b/interface/pldm_interface.hpp
@@ -1,6 +1,6 @@
 /**
- * This file is derived from the upstream implementation of send and recv
- * With an extra network id parameter
+ * The pldmSend and pldmRecv are derived from the upstream implementation of
+ * send and recv With an extra network id parameter
  */
 #pragma once
 
@@ -11,35 +11,54 @@
 #include "libpldm/pldm.h"
 #include "mctp.h"
 
-/**
- * @brief Sends the PLDM message to a specific network and eid
- *
- * @param[in] eid - Destination Eid
- * @param[in] networkId - Network id for the device
- * @param[in] mctpFd - Socket address
- * @param[in] pldmReqMsg - Pointer to the request message
- * @param[in] reqMsgLen - Length of the request message
- *
- * @return pldm_requester_rc_t
- */
-pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
-                                      const uint8_t* pldmReqMsg,
-                                      size_t reqMsgLen);
+class PldmInterface
+{
+  public:
+    PldmInterface() = default;
+    ~PldmInterface() = default;
 
-/**
- * @brief Receives the PLDM message from a specific network and eid
- *
- * @param[in] eid - Destination Eid
- * @param[in] networkId - Network id for the device
- * @param[in] mctpFd - Socket address
- * @param[in] instanceId - Instance id for the req/resp
- * @param[out] pldmRespMsg - Pointer to the response message
- * @param[out] respMsgLen - Length of the response message
- *
- * @return pldm_requester_rc_t
- */
-pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId, int mctpFd,
-                                      uint8_t instanceId, uint8_t** pldmRespMsg,
-                                      size_t* respMsgLen);
+    /**
+     * @brief Sends the PLDM message to a specific network and eid
+     *
+     * @param[in] eid - Destination Eid
+     * @param[in] networkId - Network id for the device
+     * @param[in] mctpFd - Socket address
+     * @param[in] pldmReqMsg - Pointer to the request message
+     * @param[in] reqMsgLen - Length of the request message
+     *
+     * @return pldm_requester_rc_t
+     */
+    virtual pldm_requester_rc_t pldmSendAtNetwork(mctp_eid_t eid, int networkId,
+                                                  int mctpFd,
+                                                  const uint8_t* pldmReqMsg,
+                                                  size_t reqMsgLen);
+
+    /**
+     * @brief Receives the PLDM message from a specific network and eid
+     *
+     * @param[in] eid - Destination Eid
+     * @param[in] networkId - Network id for the device
+     * @param[in] mctpFd - Socket address
+     * @param[in] instanceId - Instance id for the req/resp
+     * @param[out] pldmRespMsg - Pointer to the response message
+     * @param[out] respMsgLen - Length of the response message
+     *
+     * @return pldm_requester_rc_t
+     */
+    virtual pldm_requester_rc_t pldmRecvAtNetwork(mctp_eid_t eid, int networkId,
+                                                  int mctpFd,
+                                                  uint8_t instanceId,
+                                                  uint8_t** pldmRespMsg,
+                                                  size_t* respMsgLen);
+
+  private:
+    pldm_requester_rc_t mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
+                                          uint8_t** pldmRespMsg,
+                                          size_t* respMsgLen, int networkId);
+
+    pldm_requester_rc_t recvAtNetwork(mctp_eid_t eid, int mctpFd,
+                                      uint8_t** pldmRespMsg, size_t* respMsgLen,
+                                      int networkId);
+};
 
 #endif // BMC_PLDM_HPP
diff --git a/meson.build b/meson.build
index 1fe2184..6ab0761 100644
--- a/meson.build
+++ b/meson.build
@@ -41,10 +41,6 @@
 libbej_dep = dependency('libbej')
 libpldm_dep = dependency('libpldm')
 
-if get_option('tests').allowed()
-  subdir('tests')
-endif
-
 boost = dependency(
   'boost',
   version : '>=1.82.0',
@@ -85,6 +81,27 @@
   req_src,
 ]
 
+rded_lib = static_library(
+  'rded_lib',
+  [
+    'interface/pldm_interface.cpp',
+    'interface/pldm_rde.cpp',
+    'util/matcher/rde_match_handler.cpp',
+    'util/state_machine/discovery/base/base_disc_state_machine.cpp'
+  ],
+  include_directories: headers,
+  implicit_include_directories: false,
+  dependencies: deps)
+
+rded_lib_dep = declare_dependency(
+  dependencies: deps,
+  include_directories: headers,
+  link_with: rded_lib)
+
+if get_option('tests').allowed()
+  subdir('tests')
+endif
+
 executable(
   'rded',
   sources,
diff --git a/tests/discovery/base/base_disc_state_machine_test.cpp b/tests/discovery/base/base_disc_state_machine_test.cpp
new file mode 100644
index 0000000..7b49bb5
--- /dev/null
+++ b/tests/discovery/base/base_disc_state_machine_test.cpp
@@ -0,0 +1,93 @@
+#include "libpldm/base.h"
+
+#include "tests/pldm_interface_mock.hpp"
+#include "util/state_machine/discovery/base/base_disc_state_machine.hpp"
+
+#include <stdplus/print.hpp>
+
+#include <memory>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using ::testing::_;
+using ::testing::Return;
+
+constexpr int testFd = 3;
+constexpr int testNetId = 6;
+constexpr std::string testDeviceId = "1_1_3_1";
+
+class BaseDiscoveryTest : public ::testing::Test
+{
+  protected:
+    std::shared_ptr<MockPldmInterface> mockInterface =
+        std::make_shared<MockPldmInterface>();
+};
+
+TEST_F(BaseDiscoveryTest, StateMachineRunSuccess)
+{
+    std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+        std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+                                                    testNetId, mockInterface);
+
+    EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // send for getTid
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // send for getTypes
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // send for getVersions
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getCommands
+
+    EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _))
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // receive for getTid
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // receive for getTypes
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // receive for getVersions
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // receive for getCommands
+
+    OperationStatus status = stateMachineBase.get()->run();
+    stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+    // All invocations tested with success code (State machine runs fine)
+    EXPECT_EQ(status, OperationStatus::Success);
+}
+
+TEST_F(BaseDiscoveryTest, StateMachineRunReceiveFails)
+{
+    // Fails at the second receive and hence discovery should fail and no more
+    // calls to be made to the interface
+    std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+        std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+                                                    testNetId, mockInterface);
+
+    EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS))  // send for getTid
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS)); // send for getTypes
+
+    EXPECT_CALL(*mockInterface, pldmRecvAtNetwork(_, _, _, _, _, _))
+        .WillOnce(Return(PLDM_REQUESTER_SUCCESS)) // receive for getTid
+        .WillOnce(Return(
+            PLDM_REQUESTER_RESP_MSG_TOO_SMALL)); // receive fails for getTypes
+
+    OperationStatus status = stateMachineBase.get()->run();
+    stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+
+    // Receive failure - hence discovery fails
+    EXPECT_EQ(status, OperationStatus::PldmRecvFailure);
+}
+
+TEST_F(BaseDiscoveryTest, StateMachineRunSendFails)
+{
+    // Fails at the first send and hence discovery should fail and no more
+    // calls to be made to the interface
+    std::unique_ptr<BaseDiscoveryStateMachine> stateMachineBase =
+        std::make_unique<BaseDiscoveryStateMachine>(testFd, testDeviceId,
+                                                    testNetId, mockInterface);
+
+    EXPECT_CALL(*mockInterface, pldmSendAtNetwork(_, _, _, _, _))
+        .WillOnce(Return(PLDM_REQUESTER_SEND_FAIL)); // send for getTid fails
+
+    // No receiveFrom calls are expected as it failed on first send
+
+    OperationStatus status = stateMachineBase.get()->run();
+    stdplus::println(stderr, "Status: {}", static_cast<int>(status));
+
+    // Send failure - hence discovery fails
+    EXPECT_EQ(status, OperationStatus::PldmSendFailure);
+}
diff --git a/tests/meson.build b/tests/meson.build
index 77b7918..cd2ed39 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -19,22 +19,13 @@
     endif
 endif
 
-test(
-    'test_utils',
-    executable(
-        'test_utils',
-        'rde_test.cpp',
-        dependencies: [gtest],
-        implicit_include_directories: false,
-    )
-)
+tests = [
+    'mctp_setup_test',
+    'discovery/base/base_disc_state_machine_test'
+]
 
-test(
-    'test_mctp_setup',
-    executable(
-        'test_mctp_setup',
-        'mctp_setup_test.cpp',
-        dependencies: [gtest, req_src, stdplus],
-        implicit_include_directories: true,
-    )
-)
+foreach t : tests
+  test(t, executable(t.underscorify(), t + '.cpp',
+                     implicit_include_directories: false,
+                     dependencies: [rded_lib_dep, req_src, stdplus, gtest, gmock]))
+endforeach
diff --git a/tests/pldm_interface_mock.hpp b/tests/pldm_interface_mock.hpp
new file mode 100644
index 0000000..54c8a7b
--- /dev/null
+++ b/tests/pldm_interface_mock.hpp
@@ -0,0 +1,21 @@
+#ifndef MOCK_PLDM_INTERFACE_HPP
+#define MOCK_PLDM_INTERFACE_HPP
+
+#include "libpldm/base.h"
+#include "libpldm/pldm.h"
+
+#include "interface/pldm_interface.hpp"
+
+#include "gmock/gmock.h"
+
+class MockPldmInterface : public PldmInterface
+{
+  public:
+    MOCK_METHOD(pldm_requester_rc_t, pldmSendAtNetwork,
+                (mctp_eid_t, int, int, const uint8_t*, size_t));
+
+    MOCK_METHOD(pldm_requester_rc_t, pldmRecvAtNetwork,
+                (mctp_eid_t, int, int, uint8_t, uint8_t**, size_t*));
+};
+
+#endif // MOCK_PLDM_INTERFACE_HPP
diff --git a/tests/rde_test.cpp b/tests/rde_test.cpp
deleted file mode 100644
index 1ad6631..0000000
--- a/tests/rde_test.cpp
+++ /dev/null
@@ -1,7 +0,0 @@
-#include <gtest/gtest.h>
-
-TEST(DummyTest, DummyTest)
-{
-    // TODO(@harshtya): Add tests for RDE
-    EXPECT_TRUE(true);
-}
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.cpp b/util/state_machine/discovery/base/base_disc_state_machine.cpp
index 76f5725..23973d2 100644
--- a/util/state_machine/discovery/base/base_disc_state_machine.cpp
+++ b/util/state_machine/discovery/base/base_disc_state_machine.cpp
@@ -3,15 +3,14 @@
 #include "libpldm/base.h"
 #include "libpldm/pldm.h"
 
-#include "interface/pldm_interface.hpp"
-
 #include <stdplus/print.hpp>
 #include <util/common.hpp>
 
 BaseDiscoveryStateMachine::BaseDiscoveryStateMachine(
-    int fd, const std::string& deviceName, int netId) :
+    int fd, const std::string& deviceName, int netId,
+    std::shared_ptr<PldmInterface> pldmInterface) :
     fd(fd),
-    deviceName(deviceName), netId(netId)
+    deviceName(deviceName), netId(netId), pldmInterface(pldmInterface)
 {
     stdplus::print(stderr, "Initializing base discovery state machine...\n");
     this->initialized = true;
@@ -96,8 +95,8 @@
         return OperationStatus::EncodingRequestFailure;
     }
 
-    if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
-                          requestMsg.size()))
+    if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+                                         requestMsg.data(), requestMsg.size()))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmSendFailure;
@@ -108,8 +107,9 @@
     size_t responseMsgSize = sizeof(pldm_msg_hdr) + PLDM_GET_TID_RESP_BYTES;
     auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
 
-    if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
-                          &responseMsg, &responseMsgSize))
+    if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+                                         this->instanceId, &responseMsg,
+                                         &responseMsgSize))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmRecvFailure;
@@ -145,8 +145,8 @@
         return OperationStatus::EncodingRequestFailure;
     }
 
-    if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
-                          requestMsg.size()))
+    if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+                                         requestMsg.data(), requestMsg.size()))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmSendFailure;
@@ -157,8 +157,9 @@
     size_t responseMsgSize = response.size();
     auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
 
-    if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
-                          &responseMsg, &responseMsgSize))
+    if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+                                         this->instanceId, &responseMsg,
+                                         &responseMsgSize))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmRecvFailure;
@@ -198,8 +199,8 @@
         return OperationStatus::EncodingRequestFailure;
     }
 
-    if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
-                          requestMsg.size()))
+    if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+                                         requestMsg.data(), requestMsg.size()))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmSendFailure;
@@ -211,8 +212,9 @@
     size_t responseMsgSize = response.size();
     auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
 
-    if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
-                          &responseMsg, &responseMsgSize))
+    if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+                                         this->instanceId, &responseMsg,
+                                         &responseMsgSize))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmRecvFailure;
@@ -254,8 +256,8 @@
         return OperationStatus::EncodingRequestFailure;
     }
 
-    if (pldmSendAtNetwork(this->eid, this->netId, this->fd, requestMsg.data(),
-                          requestMsg.size()))
+    if (pldmInterface->pldmSendAtNetwork(this->eid, this->netId, this->fd,
+                                         requestMsg.data(), requestMsg.size()))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmSendFailure;
@@ -266,8 +268,9 @@
     size_t responseMsgSize = response.size();
     auto responsePtr = reinterpret_cast<struct pldm_msg*>(responseMsg);
 
-    if (pldmRecvAtNetwork(this->eid, this->netId, this->fd, this->instanceId,
-                          &responseMsg, &responseMsgSize))
+    if (pldmInterface->pldmRecvAtNetwork(this->eid, this->netId, this->fd,
+                                         this->instanceId, &responseMsg,
+                                         &responseMsgSize))
     {
         this->requesterStatus = StateMachineStatus::RequestFailed;
         return OperationStatus::PldmRecvFailure;
diff --git a/util/state_machine/discovery/base/base_disc_state_machine.hpp b/util/state_machine/discovery/base/base_disc_state_machine.hpp
index 55091bb..d773fcc 100644
--- a/util/state_machine/discovery/base/base_disc_state_machine.hpp
+++ b/util/state_machine/discovery/base/base_disc_state_machine.hpp
@@ -4,9 +4,11 @@
 #include "libpldm/base.h"
 #include "libpldm/pldm.h"
 
+#include "interface/pldm_interface.hpp"
 #include "state_machine_factory.hpp"
 
 #include <array>
+#include <memory>
 #include <optional>
 #include <unordered_map>
 
@@ -26,7 +28,8 @@
 class BaseDiscoveryStateMachine : public StateMachineFactory
 {
   public:
-    BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId);
+    BaseDiscoveryStateMachine(int fd, const std::string& deviceName, int netId,
+                              std::shared_ptr<PldmInterface> pldmInterface);
 
     OperationStatus run() override;
 
@@ -82,6 +85,8 @@
     std::array<std::array<uint8_t, PLDM_MAX_CMDS_PER_TYPE>, PLDM_MAX_TYPES>
         pldmCommands;
     std::array<ver32_t, PLDM_MAX_TYPES> pldmVersions;
+
+    std::shared_ptr<PldmInterface> pldmInterface;
 };
 
 #endif // BASE_DISC_STATE_MACHINE_HPP