// blob: b857ec361798d984817cecfc20fa17387bce33ca [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "libusb_mock.hpp"
#include "message_usb.hpp"
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
namespace google
{
namespace hoth
{
using ::testing::_;
using ::testing::AnyNumber;
using ::testing::DoAll;
using ::testing::Ge;
using ::testing::Return;
using ::testing::SetArgPointee;
ACTION_P(CopyNumbers, nums)
{
    // Stands in for libusb_get_port_numbers(dev, port_numbers, len): arg1 is
    // the output array and arg2 its capacity. Returns the number of entries
    // written, or LIBUSB_ERROR_OVERFLOW when they would not fit.
    const auto count = static_cast<decltype(arg2)>(nums.size());
    if (count > arg2)
    {
        return LIBUSB_ERROR_OVERFLOW;
    }
    // A device with no ports has nothing to copy (and an empty vector may
    // report a null data() pointer).
    if (!nums.empty())
    {
        memcpy(arg1, nums.data(), count);
    }
    return count;
}
class MessageUSBTest : public ::testing::Test
{
  protected:
    // Installs the baseline libusb expectations shared by every test: a fake
    // context, and enumeration that exposes whatever add_device() registered.
    MessageUSBTest()
    {
        // Every session hands back the fake context `ctx`.
        EXPECT_CALL(libusb, init(_))
            .WillRepeatedly(DoAll(SetArgPointee<0>(ctx), Return(0)));
        EXPECT_CALL(libusb, exit(ctx)).Times(AnyNumber());
        // Enumeration returns the nullptr-terminated `devices` table as it
        // stands at call time.
        // NOTE(review): returns 0 rather than the device count libusb would
        // report; the code under test presumably walks the list to the
        // nullptr sentinel instead — confirm.
        EXPECT_CALL(libusb, get_device_list(ctx, _))
            .WillRepeatedly([this](libusb_context*, libusb_device*** out) {
                *out = devices.data();
                return 0;
            });
        EXPECT_CALL(libusb, free_device_list(_, 0)).Times(AnyNumber());
    }

    // Registers a fake device whose handle is the integer `devn` reinterpreted
    // as a pointer, reporting bus number `bus` and port path `ports`.
    void add_device(int devn, int bus, const std::vector<uint8_t>& ports)
    {
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        auto* const fake = reinterpret_cast<libusb_device*>(devn);
        EXPECT_CALL(libusb, get_bus_number(fake)).WillRepeatedly(Return(bus));
        EXPECT_CALL(libusb, get_port_numbers(fake, _, _))
            .WillRepeatedly(CopyNumbers(ports));
        EXPECT_CALL(libusb, unref_device(fake)).Times(AnyNumber());
        // Overwrite the trailing sentinel with the new device, then
        // re-terminate the list.
        devices.back() = fake;
        devices.push_back(nullptr);
    }

    // Fake device list; always ends with a nullptr sentinel.
    std::vector<libusb_device*> devices = {nullptr};
    ::testing::StrictMock<LibusbMock> libusb;
    libusb_context* const ctx = reinterpret_cast<libusb_context*>(1);
};
TEST_F(MessageUSBTest, ParseDeviceString)
{
    // Depends on distinguishing invalid bus+port strings from valid ones
    // based on the MessageUSB throwing invalid_argument for bad input names.
    for (const char* malformed : {"", "1", "a-1", "1-a", "1-.1", "1-1.3-"})
    {
        SCOPED_TRACE(malformed);
        EXPECT_THROW(MessageUSB(malformed, &libusb), std::invalid_argument);
    }
    // Well-formed names still fail, but with runtime_error, because this
    // test registers no devices for them to match.
    for (const char* absent : {"1-3", "1-1.3"})
    {
        SCOPED_TRACE(absent);
        EXPECT_THROW(MessageUSB(absent, &libusb), std::runtime_error);
    }
}
TEST_F(MessageUSBTest, MatchDevice)
{
    add_device(2, 1, {});
    add_device(3, 1, {1, 2});
    add_device(4, 1, {2, 1});
    // Depends on the mock never being asked for the active config_descriptor
    // when the device could not be found. Only once it finds a valid device
    // during lookup will it check the config_descriptor.
    EXPECT_THROW(MessageUSB("1-3", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1.1", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1.2.1", &libusb), std::runtime_error);
    // Device 3 (bus 1, ports 1.2) matches "1-1.2". Its config descriptor has
    // no interfaces, so construction still throws (presumably because no
    // mailbox interface is found). Value-initialized so that every field
    // other than bNumInterfaces is zero/nullptr rather than indeterminate
    // when handed to the code under test.
    libusb_config_descriptor cd{};
    cd.bNumInterfaces = 0;
    EXPECT_CALL(libusb, get_active_config_descriptor(
                            reinterpret_cast<libusb_device*>(3), _))
        .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0)));
    EXPECT_THROW(MessageUSB("1-1.2", &libusb), std::runtime_error);
}
// Fixture that successfully opens a MessageUSB on device "1-1.1" whose active
// configuration exposes one hoth mailbox interface with a bulk OUT/IN pair.
class MessageUSBMemberTest : public MessageUSBTest
{
  protected:
    MessageUSBMemberTest()
    {
        // ed[0] is intentionally a non-BULK transfer endpoint
        // to ensure we filter correctly.
        ed[0].bmAttributes = LIBUSB_TRANSFER_TYPE_CONTROL;
        // ed[1]/ed[2] are the bulk OUT/IN pair on endpoint 3. Each max packet
        // size leaves room for 2 payload bytes after the header, presumably
        // what limits each chunk to 2 bytes in the *Multi tests.
        ed[1].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK;
        ed[1].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_OUT;
        ed[1].wMaxPacketSize = sizeof(MessageUSB::ReqHeader) + 2;
        ed[2].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK;
        ed[2].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_IN;
        ed[2].wMaxPacketSize = sizeof(MessageUSB::RspHeader) + 2;
        // id[0] is intentionally not a hoth mailbox to ensure we filter
        // for the correct interface.
        id[0].bInterfaceClass = MessageUSB::kMailboxClass;
        id[0].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1] - 1;
        id[0].bInterfaceProtocol = MessageUSB::kMailboxProtocol;
        id[1].endpoint = ed;
        id[1].bNumEndpoints = 3;
        id[1].bInterfaceClass = MessageUSB::kMailboxClass;
        id[1].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1];
        id[1].bInterfaceProtocol = MessageUSB::kMailboxProtocol;
        id[1].bInterfaceNumber = 5;
        // intf[0] is intentionally empty to ensure we are looping over these
        intf[0].num_altsetting = 0;
        intf[1].altsetting = id;
        intf[1].num_altsetting = 2;
        cd.interface = intf;
        cd.bNumInterfaces = 2;
        // Device 2 on bus 1 behind ports 1.1 is the "1-1.1" device opened
        // below; it becomes devices[0].
        add_device(2, 1, {1, 1});
        EXPECT_CALL(libusb, get_active_config_descriptor(devices[0], _))
            .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0)));
        EXPECT_CALL(libusb, open(devices[0], _))
            .WillOnce(DoAll(SetArgPointee<1>(handle), Return(0)));
        EXPECT_CALL(libusb, close(handle)).Times(1);
        EXPECT_CALL(libusb, claim_interface(handle, id[1].bInterfaceNumber))
            .WillOnce(Return(0));
        EXPECT_CALL(libusb, release_interface(handle, id[1].bInterfaceNumber))
            .WillOnce(Return(0));
        msg = std::make_unique<MessageUSB>("1-1.1", &libusb, /*unit_test=*/true);
    }
    libusb_device_handle* const handle =
        reinterpret_cast<libusb_device_handle*>(10);
    // Value-initialized so every descriptor field the constructor does not
    // set explicitly is zero/nullptr rather than indeterminate when the code
    // under test walks these structures.
    libusb_endpoint_descriptor ed[3]{};
    libusb_config_descriptor cd{};
    libusb_interface intf[2]{};
    libusb_interface_descriptor id[2]{};
    // Declared last so it is destroyed first, while the descriptors above
    // and the base-class mock are still alive.
    std::unique_ptr<MessageUSB> msg = nullptr;
};
// Exercises only the fixture. Constructing MessageUSB in the
// MessageUSBMemberTest constructor must satisfy the get_active_config_descriptor
// / open / claim_interface expectations. The Times(1)/WillOnce expectations on
// release_interface and close must also be met — presumably when `msg` is
// destroyed at fixture teardown.
TEST_F(MessageUSBMemberTest, CompleteOpen)
{}
// Checks an OUT bulk transfer against the expected bytes in `bulk`: arg2 is
// the submitted data, arg3 its length, and arg4 receives the "transferred"
// count, which is set to the full expected length. In this file it is only
// used as a non-final DoAll() step, so it returns nothing.
ACTION_P(CompareBulk, bulk)
{
    // Compare lengths in arg3's own type to avoid a signed/unsigned mismatch.
    const auto expected_len = static_cast<decltype(arg3)>(bulk.size());
    EXPECT_EQ(expected_len, arg3);
    // Only compare contents when the lengths agree. Otherwise memcmp could
    // read past the end of the shorter buffer.
    if (arg3 == expected_len)
    {
        EXPECT_EQ(0, memcmp(bulk.data(), arg2, bulk.size()));
    }
    *arg4 = expected_len;
}
ACTION_P(PopulateBulk, out)
{
    // Plays the device side of an IN bulk transfer: copies `out` into the
    // caller's buffer (arg2, capacity arg3), reports the byte count through
    // arg4 and returns 0 (success).
    const auto payload = out.size();
    if (static_cast<decltype(arg3)>(payload) > arg3)
    {
        // The caller's buffer cannot hold the canned response.
        throw std::invalid_argument("populate");
    }
    memcpy(arg2, out.data(), payload);
    *arg4 = payload;
    return 0;
}
// Ensure truncated request transfers result in errors: the OUT transfer
// reports one byte fewer than was submitted, so send() is expected to throw.
TEST_F(MessageUSBMemberTest, SendPartialOut)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Value-initialized so header bytes not assigned below are zero rather
    // than indeterminate when copied into the expected `bulk` image.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1),
                        Return(0)));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}
// Ensure truncated response transfers result in errors: the request goes out
// intact, but the IN transfer returns fewer bytes than a RspHeader.
TEST_F(MessageUSBMemberTest, SendPartialIn)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Value-initialized so header bytes not assigned below are zero rather
    // than indeterminate when copied into the expected `bulk` image.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(MessageUSB::RspHeader)), _, 0))
        .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1),
                        Return(0)));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}
// Ensure non-zero response status codes result in errors
TEST_F(MessageUSBMemberTest, SendBadStatus)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 1;
    bulk.resize(sizeof(rsp));
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(rsp)), _, 0))
        .WillOnce(PopulateBulk(bulk));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}
// Ensure responses that are too long get truncated
TEST_F(MessageUSBMemberTest, SendLongRsp)
{
    std::vector<uint8_t> buf = {1};
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = buf.size();
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    // The response deliberately reuses the request-sized `bulk` buffer, so
    // any bytes past the RspHeader are leftovers that send() is expected to
    // tolerate. Guard the overwrite: the buffer must be able to hold a
    // RspHeader, or the memcpy below would overflow it.
    ASSERT_GE(bulk.size(), sizeof(rsp));
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(bulk.size()), _, 0))
        .WillOnce(PopulateBulk(bulk));
    msg->send(buf.data(), buf.size(), 3);
}
// Ensure we can successfully send a chunked message if the USB channel
// isn't wide enough to do it in a single transfer request: 3 bytes at offset
// 3 go out as a 2-byte chunk at offset 3 followed by a 1-byte chunk at
// offset 5, each acknowledged by a status-0 response.
TEST_F(MessageUSBMemberTest, SendMulti)
{
    std::vector<uint8_t> buf = {1, 2, 3};
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    bulk[sizeof(req)] = buf[0];
    bulk[sizeof(req) + 1] = buf[1];
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp));
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    // The same status-0 response answers both chunks.
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(rsp)), _, 0))
        .WillRepeatedly(PopulateBulk(bulk));
    req.offset = 5;
    req.length = 1;
    bulk.resize(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    bulk[sizeof(req)] = buf[2];
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    msg->send(buf.data(), buf.size(), 3);
}
// Ensure truncated request transfers result in errors: the OUT transfer of the
// read request reports one byte short, so recv() is expected to throw.
TEST_F(MessageUSBMemberTest, RecvPartialOut)
{
    // Value-initialized so header bytes not assigned below are zero rather
    // than indeterminate when copied into the expected `bulk` image.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    // A read request is header-only; the data comes back on the IN side.
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1),
                        Return(0)));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}
// Ensure truncated response transfers result in errors: the read request goes
// out intact, but the IN transfer returns fewer bytes than a RspHeader.
TEST_F(MessageUSBMemberTest, RecvPartialIn)
{
    // Value-initialized so header bytes not assigned below are zero rather
    // than indeterminate when copied into the expected `bulk` image.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1),
                        Return(0)));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}
// Ensure extra response data is discarded: one byte is requested, the IN
// transfer returns two (3 then 4), and only the 3 may land in `recvd`.
TEST_F(MessageUSBMemberTest, RecvLongRsp)
{
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 1;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp) + req.length + 1);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 3;
    bulk[sizeof(rsp) + 1] = 4;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    std::vector<uint8_t> recvd(1);
    msg->recv(recvd.data(), recvd.size(), 3);
    std::vector<uint8_t> expected{3};
    EXPECT_EQ(expected, recvd);
}
// Ensure non-zero response status codes result in errors
TEST_F(MessageUSBMemberTest, RecvBadStatus)
{
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 1;
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}
// Ensure we can successfully receive chunked messages if the USB channel
// isn't wide enough to do it in a single transfer request: 3 bytes at offset
// 3 arrive as a 2-byte chunk (2, 1) followed by a 1-byte chunk (3) at
// offset 5, with each request/response pair occurring in order.
TEST_F(MessageUSBMemberTest, RecvMulti)
{
    testing::InSequence seq;
    // Headers are value-initialized so any bytes not assigned below are zero
    // rather than indeterminate when copied into transfer buffers.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 2;
    bulk[sizeof(rsp) + 1] = 1;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    req.offset = 5;
    req.length = 1;
    bulk.resize(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 3;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    std::vector<uint8_t> recvd(3);
    msg->recv(recvd.data(), recvd.size(), 3);
    std::vector<uint8_t> expected{2, 1, 3};
    EXPECT_EQ(expected, recvd);
}
} // namespace hoth
} // namespace google