// blob: 1ac7279a887e55df7c2d5d1aea2408eb6194659a [file] [edit]
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// NSMD_SOCKET_REDIRECT must NOT be defined here — the macros would expand
// MOCK_METHOD-generated method definitions that share syscall names.
#include "nsmd/socketIo.hpp"

#include <libnsm/base.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
/**
* @brief GMock implementation of SocketIoInterface.
*
* A single NiceMock<MockSocketIo> instance is shared across all tests in a
* binary (see mockSocketIo.cpp). Access it via getMockSocketIo().
*/
class MockSocketIo : public SocketIoInterface
{
public:
// Every method below is a gMock override of the SocketIoInterface virtual
// with the same name. The names deliberately match the corresponding
// syscalls, which is why NSMD_SOCKET_REDIRECT must stay undefined when this
// header is included (see the note next to the includes).

// --- Socket lifecycle ---
MOCK_METHOD(int, socket, (int, int, int), (override));
MOCK_METHOD(int, bind, (int, const struct sockaddr*, socklen_t),
(override));
MOCK_METHOD(int, connect, (int, const struct sockaddr*, socklen_t),
(override));
MOCK_METHOD(int, close, (int), (override));
// --- Transmit path ---
MOCK_METHOD(ssize_t, send, (int, const void*, size_t, int), (override));
MOCK_METHOD(ssize_t, sendto,
(int, const void*, size_t, int, const struct sockaddr*,
socklen_t),
(override));
MOCK_METHOD(ssize_t, sendmsg, (int, const struct msghdr*, int), (override));
MOCK_METHOD(ssize_t, write, (int, const void*, size_t), (override));
// --- Receive path ---
MOCK_METHOD(ssize_t, recv, (int, void*, size_t, int), (override));
MOCK_METHOD(ssize_t, recvfrom,
(int, void*, size_t, int, struct sockaddr*, socklen_t*),
(override));
MOCK_METHOD(ssize_t, recvmsg, (int, struct msghdr*, int), (override));
// --- Readiness, options and device control ---
MOCK_METHOD(int, poll, (struct pollfd*, nfds_t, int), (override));
MOCK_METHOD(int, getsockopt, (int, int, int, void*, socklen_t*),
(override));
MOCK_METHOD(int, setsockopt, (int, int, int, const void*, socklen_t),
(override));
// NOTE(review): takes a fixed void* third argument rather than the variadic
// form of the ioctl syscall; this presumably mirrors SocketIoInterface.
MOCK_METHOD(int, ioctl, (int, unsigned long, void*), (override));
};
/**
 * @brief Returns the MockSocketIo singleton used by getSocketIo() in test
 *        binaries.
 *
 * @return Reference to the shared mock; only declared here, the definition
 *         lives elsewhere (mockSocketIo.cpp per the MockSocketIo class
 *         comment).
 */
MockSocketIo& getMockSocketIo();
/**
* @brief Captures sent MCTP messages and provides canned responses.
*/
class MctpMessageCapture
{
  public:
    /** One message recorded through recordSentMessage(). */
    struct SentMessage
    {
        uint8_t tag;
        uint8_t eid;
        uint8_t msgType;
        std::vector<uint8_t> payload;
    };

    /** One canned response queued through enqueueMctpResponse(). */
    struct Response
    {
        uint8_t eid;
        uint8_t tag;
        std::vector<uint8_t> mctpPayload; // 3-byte header + NSM payload
    };

    /**
     * @brief Queue a canned response (FIFO order).
     *
     * The stored buffer is {tag, eid, 0x7E} followed by a copy of
     * @p nsmPayload.
     *
     * @param eid        Endpoint ID stored in the response and in byte 1.
     * @param tag        Tag stored in the response and in byte 0.
     * @param nsmPayload Bytes appended after the 3-byte header.
     */
    void enqueueMctpResponse(uint8_t eid, uint8_t tag,
                             const std::vector<uint8_t>& nsmPayload)
    {
        Response resp;
        resp.eid = eid;
        resp.tag = tag;
        resp.mctpPayload.reserve(3 + nsmPayload.size());
        resp.mctpPayload.push_back(tag);
        resp.mctpPayload.push_back(eid);
        resp.mctpPayload.push_back(0x7E); // VDM type
        resp.mctpPayload.insert(resp.mctpPayload.end(), nsmPayload.begin(),
                                nsmPayload.end());
        // Hand the buffer over instead of deep-copying it into the queue.
        responses_.push(std::move(resp));
    }

    /**
     * @brief Append a sent message to the capture log.
     *
     * @param tag     Tag of the sent message.
     * @param eid     Endpoint ID of the sent message.
     * @param msgType Message type byte of the sent message.
     * @param payload Payload bytes; copied into the log.
     */
    void recordSentMessage(uint8_t tag, uint8_t eid, uint8_t msgType,
                           const std::vector<uint8_t>& payload)
    {
        sent_.push_back(SentMessage{tag, eid, msgType, payload});
    }

    /** @return All recorded messages, oldest first. */
    const std::vector<SentMessage>& getSentMessages() const
    {
        return sent_;
    }

    /**
     * @brief Pop the oldest queued response.
     *
     * @param[out] resp Receives the response; left untouched when the queue
     *                  is empty.
     * @return true if a response was dequeued, false if none was queued.
     */
    bool getNextResponse(Response& resp)
    {
        if (responses_.empty())
        {
            return false;
        }
        // The front element is discarded right after, so steal its buffer.
        resp = std::move(responses_.front());
        responses_.pop();
        return true;
    }

    /** @brief Drop every recorded message and every queued response. */
    void clearMessages()
    {
        sent_.clear();
        responses_ = {};
    }

  private:
    std::queue<Response> responses_;
    std::vector<SentMessage> sent_;
};
/**
* @brief Helper for building NSM test messages.
*/
class NsmMessageBuilder
{
  public:
    /**
     * @brief Build a 9-byte response with a zero status byte.
     *
     * Layout: 0xDE, 0x10, (instanceId & 0x1F), 0x00, cmdId, then four zero
     * bytes.
     * NOTE(review): the first two bytes look like a vendor ID and the byte
     * after cmdId like a completion code -- confirm against libnsm.
     *
     * @param cmdId      Value written at index 4.
     * @param instanceId Only the low five bits are used (index 2).
     * @return The assembled byte vector.
     */
    static std::vector<uint8_t> successResponse(uint8_t cmdId,
                                                uint8_t instanceId)
    {
        const auto instanceBits = static_cast<uint8_t>(instanceId & 0x1F);
        return {0xDE, 0x10, instanceBits, 0x00, cmdId,
                0x00, 0x00, 0x00,         0x00};
    }

    /**
     * @brief Build a 9-byte response carrying @p errorCode at index 5.
     *
     * Identical to successResponse() except for the byte at index 5.
     *
     * @param cmdId      Value written at index 4.
     * @param instanceId Only the low five bits are used (index 2).
     * @param errorCode  Value written at index 5.
     * @return The assembled byte vector.
     */
    static std::vector<uint8_t> errorResponse(uint8_t cmdId, uint8_t instanceId,
                                              uint8_t errorCode)
    {
        const auto instanceBits = static_cast<uint8_t>(instanceId & 0x1F);
        return {0xDE,      0x10, instanceBits, 0x00, cmdId,
                errorCode, 0x00, 0x00,         0x00};
    }

    /**
     * @brief Build a request: 5-byte header followed by @p payload.
     *
     * Layout: 0xDE, 0x10, 0x80 | (instanceId & 0x1F), 0x00, cmdId, payload.
     * NOTE(review): bit 7 of index 2 is presumably the request flag --
     * confirm against libnsm.
     *
     * @param cmdId      Value written at index 4.
     * @param instanceId Only the low five bits are used (index 2).
     * @param payload    Bytes appended verbatim after the header.
     * @return The assembled byte vector.
     */
    static std::vector<uint8_t>
        request(uint8_t cmdId, uint8_t instanceId,
                const std::vector<uint8_t>& payload = {})
    {
        const auto flaggedInstance =
            static_cast<uint8_t>(0x80 | (instanceId & 0x1F));
        std::vector<uint8_t> message{0xDE, 0x10, flaggedInstance, 0x00, cmdId};
        message.insert(message.end(), payload.cbegin(), payload.cend());
        return message;
    }
};
/**
* @brief Base fixture for MCTP socket tests.
*
* Provides access to the shared MockSocketIo singleton and resets it between
* tests. No injection/restore needed — the mock is always active in test
* binaries.
*/
class MctpTestFixture : public ::testing::Test
{
  protected:
    /** Shared mock singleton; not owned by the fixture. */
    MockSocketIo* mockIo_ = &getMockSocketIo();
    /** Per-test capture of sent messages and queued responses. */
    std::unique_ptr<MctpMessageCapture> capture_;

    /** Creates a fresh MctpMessageCapture for the test. */
    void SetUp() override
    {
        capture_ = std::make_unique<MctpMessageCapture>();
    }

    /**
     * Verifies and clears the expectations set on the shared mock, then
     * empties the capture. Clearing the expectations here also destroys the
     * actions that captured `this` before the fixture goes away.
     */
    void TearDown() override
    {
        ::testing::Mock::VerifyAndClearExpectations(mockIo_);
        // A derived SetUp() that skips the base call leaves capture_ empty.
        if (capture_)
        {
            capture_->clearMessages();
        }
    }

    /**
     * @brief Expect one socket() call for @p domain and make it return
     *        @p fakeFd.
     *
     * @param fakeFd Descriptor the mock hands back.
     * @param domain Address family the call must use; type and protocol are
     *               not checked.
     */
    void expectSocketCreation(int fakeFd = 42, int domain = AF_MCTP)
    {
        using ::testing::_;
        using ::testing::Return;
        EXPECT_CALL(*mockIo_, socket(domain, _, _)).WillOnce(Return(fakeFd));
    }

    /**
     * @brief Expect one sendmsg(fd, _, 0) whose message is a 3-byte
     *        {tag, eid, 0x7E} header iovec followed by one payload iovec.
     *
     * The payload is recorded in capture_. When the iovec layout is not the
     * expected one, the test is marked failed and the action returns -1
     * without touching the missing buffers.
     *
     * @param fd          Descriptor the call must use.
     * @param expectedEid Expected header byte 1.
     * @param expectedTag Expected header byte 0.
     */
    void expectMctpSendMsg(int fd, uint8_t expectedEid, uint8_t expectedTag)
    {
        using ::testing::_;
        EXPECT_CALL(*mockIo_, sendmsg(fd, _, 0))
            .WillOnce([this, expectedEid, expectedTag](
                          int, const struct msghdr* msg, int) -> ssize_t {
                const size_t iovCount = static_cast<size_t>(msg->msg_iovlen);
                EXPECT_EQ(iovCount, 2u);
                if (iovCount != 2)
                {
                    // EXPECT_EQ is non-fatal; bail out before reading iov[1].
                    return -1;
                }
                const struct iovec* iov = msg->msg_iov;
                EXPECT_EQ(iov[0].iov_len, 3u);
                if (iov[0].iov_len != 3)
                {
                    // Header bytes cannot be inspected safely.
                    return -1;
                }
                const uint8_t* hdr =
                    static_cast<const uint8_t*>(iov[0].iov_base);
                EXPECT_EQ(hdr[0], expectedTag);
                EXPECT_EQ(hdr[1], expectedEid);
                EXPECT_EQ(hdr[2], 0x7E);
                const uint8_t* payload =
                    static_cast<const uint8_t*>(iov[1].iov_base);
                std::vector<uint8_t> payloadVec(payload,
                                                payload + iov[1].iov_len);
                capture_->recordSentMessage(expectedTag, expectedEid, 0x7E,
                                            payloadVec);
                return static_cast<ssize_t>(iov[0].iov_len + iov[1].iov_len);
            });
    }

    /**
     * @brief Expect a length peek followed by one recvmsg() delivering
     *        @p response.
     *
     * recv(fd, nullptr, 0, MSG_PEEK | MSG_TRUNC) returns response.size().
     * recvmsg(fd, _, 0) scatters the bytes across the caller's iovecs and
     * returns the number of bytes actually copied, so an undersized buffer
     * yields a short count instead of a length the buffers do not hold.
     *
     * @param fd       Descriptor both calls must use.
     * @param response Bytes to deliver; copied into the action.
     */
    void expectMctpRecvWithPeek(int fd, const std::vector<uint8_t>& response)
    {
        using ::testing::_;
        using ::testing::Return;
        EXPECT_CALL(*mockIo_, recv(fd, nullptr, 0, MSG_PEEK | MSG_TRUNC))
            .WillOnce(Return(static_cast<ssize_t>(response.size())));
        EXPECT_CALL(*mockIo_, recvmsg(fd, _, 0))
            .WillOnce([response](int, struct msghdr* msg, int) -> ssize_t {
                const size_t iovCount = static_cast<size_t>(msg->msg_iovlen);
                size_t copied = 0;
                for (size_t i = 0; i < iovCount && copied < response.size();
                     ++i)
                {
                    const struct iovec& segment = msg->msg_iov[i];
                    const size_t chunk =
                        std::min(segment.iov_len, response.size() - copied);
                    if (chunk != 0)
                    {
                        std::memcpy(segment.iov_base,
                                    response.data() + copied, chunk);
                    }
                    copied += chunk;
                }
                return static_cast<ssize_t>(copied);
            });
    }

    /**
     * @brief Expect one poll() on a single descriptor with the given
     *        timeout.
     *
     * The action checks that the pollfd targets @p fd with events == POLLIN.
     * revents is always written: POLLIN when @p result is positive, zero
     * otherwise, so stale caller data never leaks through.
     *
     * @param fd      Descriptor expected in fds[0].
     * @param timeout Timeout argument the call must use.
     * @param result  Value the mocked poll() returns.
     */
    void expectPoll(int fd, int timeout, int result = 1)
    {
        using ::testing::_;
        using ::testing::Return;
        EXPECT_CALL(*mockIo_, poll(_, 1, timeout))
            .WillOnce([fd, result](struct pollfd* fds, nfds_t, int) {
                EXPECT_EQ(fds[0].fd, fd);
                EXPECT_EQ(fds[0].events, POLLIN);
                fds[0].revents = static_cast<short>(result > 0 ? POLLIN : 0);
                return result;
            });
    }
};