| // Copyright 2024 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "libusb_mock.hpp"
#include "message_usb.hpp"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
| |
| namespace google |
| { |
| namespace hoth |
| { |
| |
| using ::testing::_; |
| using ::testing::AnyNumber; |
| using ::testing::DoAll; |
| using ::testing::Ge; |
| using ::testing::Return; |
| using ::testing::SetArgPointee; |
| |
| ACTION_P(CopyNumbers, nums) |
| { |
| const auto c = static_cast<decltype(arg2)>(nums.size()); |
| if (arg2 < c) |
| { |
| return LIBUSB_ERROR_OVERFLOW; |
| } |
| const auto copied = std::min(arg2, c); |
| if (nums.data() != nullptr) { |
| memcpy(arg1, nums.data(), copied); |
| } |
| return copied; |
| } |
| |
| class MessageUSBTest : public ::testing::Test |
| { |
| protected: |
| MessageUSBTest() |
| { |
| EXPECT_CALL(libusb, init(_)) |
| .WillRepeatedly(DoAll(SetArgPointee<0>(ctx), Return(0))); |
| EXPECT_CALL(libusb, exit(ctx)).Times(AnyNumber()); |
| EXPECT_CALL(libusb, get_device_list(ctx, _)) |
| .WillRepeatedly([&](libusb_context*, libusb_device*** list) { |
| *list = devices.data(); |
| return 0; |
| }); |
| EXPECT_CALL(libusb, free_device_list(_, 0)).Times(AnyNumber()); |
| } |
| |
| void add_device(int devn, int bus, const std::vector<uint8_t>& ports) |
| { |
| // NOLINTNEXTLINE(performance-no-int-to-ptr) |
| libusb_device* dev = reinterpret_cast<libusb_device*>(devn); |
| EXPECT_CALL(libusb, get_bus_number(dev)).WillRepeatedly(Return(bus)); |
| EXPECT_CALL(libusb, get_port_numbers(dev, _, _)) |
| .WillRepeatedly(CopyNumbers(ports)); |
| EXPECT_CALL(libusb, unref_device(dev)).Times(AnyNumber()); |
| devices[devices.size() - 1] = dev; |
| devices.push_back(nullptr); |
| } |
| |
| std::vector<libusb_device*> devices = {nullptr}; |
| ::testing::StrictMock<LibusbMock> libusb; |
| libusb_context* const ctx = reinterpret_cast<libusb_context*>(1); |
| }; |
| |
| TEST_F(MessageUSBTest, ParseDeviceString) |
| { |
| // Depends on distinguishing invalid bus+port strings from valid ones |
| // based on the MessageUSB throwing invalid_argument for bad input names. |
| EXPECT_THROW(MessageUSB("", &libusb), std::invalid_argument); |
| EXPECT_THROW(MessageUSB("1", &libusb), std::invalid_argument); |
| EXPECT_THROW(MessageUSB("a-1", &libusb), std::invalid_argument); |
| EXPECT_THROW(MessageUSB("1-a", &libusb), std::invalid_argument); |
| EXPECT_THROW(MessageUSB("1-.1", &libusb), std::invalid_argument); |
| EXPECT_THROW(MessageUSB("1-1.3-", &libusb), std::invalid_argument); |
| |
| EXPECT_THROW(MessageUSB("1-3", &libusb), std::runtime_error); |
| EXPECT_THROW(MessageUSB("1-1.3", &libusb), std::runtime_error); |
| } |
| |
| TEST_F(MessageUSBTest, MatchDevice) |
| { |
| add_device(2, 1, {}); |
| add_device(3, 1, {1, 2}); |
| add_device(4, 1, {2, 1}); |
| |
| // Depends on the mock never being asked for the active config_descriptor |
| // when the device could not be found. Only once it finds a valid device |
| // during lookup will it check the config_descriptor. |
| EXPECT_THROW(MessageUSB("1-3", &libusb), std::runtime_error); |
| EXPECT_THROW(MessageUSB("1-1", &libusb), std::runtime_error); |
| EXPECT_THROW(MessageUSB("1-1.1", &libusb), std::runtime_error); |
| EXPECT_THROW(MessageUSB("1-1.2.1", &libusb), std::runtime_error); |
| |
| libusb_config_descriptor cd; |
| cd.bNumInterfaces = 0; |
| EXPECT_CALL(libusb, get_active_config_descriptor( |
| reinterpret_cast<libusb_device*>(3), _)) |
| .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0))); |
| EXPECT_THROW(MessageUSB("1-1.2", &libusb), std::runtime_error); |
| } |
| |
| class MessageUSBMemberTest : public MessageUSBTest |
| { |
| protected: |
| MessageUSBMemberTest() |
| { |
| // ed[0] is intentionally a non BULK type transfer endpoint |
| // to ensure we filter correctly. |
| ed[0].bmAttributes = LIBUSB_TRANSFER_TYPE_CONTROL; |
| ed[1].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK; |
| ed[1].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_OUT; |
| ed[1].wMaxPacketSize = sizeof(MessageUSB::ReqHeader) + 2; |
| ed[2].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK; |
| ed[2].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_IN; |
| ed[2].wMaxPacketSize = sizeof(MessageUSB::RspHeader) + 2; |
| |
| // id[0] is intentionally not a hoth mailnbox to ensure we filter |
| // for the correct interface. |
| id[0].bInterfaceClass = MessageUSB::kMailboxClass; |
| id[0].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1] - 1; |
| id[0].bInterfaceProtocol = MessageUSB::kMailboxProtocol; |
| id[1].endpoint = ed; |
| id[1].bNumEndpoints = 3; |
| id[1].bInterfaceClass = MessageUSB::kMailboxClass; |
| id[1].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1]; |
| id[1].bInterfaceProtocol = MessageUSB::kMailboxProtocol; |
| id[1].bInterfaceNumber = 5; |
| |
| // intf[0] is intentionally empty to ensure we are looping over these |
| intf[0].num_altsetting = 0; |
| intf[1].altsetting = id; |
| intf[1].num_altsetting = 2; |
| cd.interface = intf; |
| cd.bNumInterfaces = 2; |
| |
| add_device(2, 1, {1, 1}); |
| EXPECT_CALL(libusb, get_active_config_descriptor(devices[0], _)) |
| .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0))); |
| EXPECT_CALL(libusb, open(devices[0], _)) |
| .WillOnce(DoAll(SetArgPointee<1>(handle), Return(0))); |
| EXPECT_CALL(libusb, close(handle)).Times(1); |
| EXPECT_CALL(libusb, claim_interface(handle, id[1].bInterfaceNumber)) |
| .WillOnce(Return(0)); |
| EXPECT_CALL(libusb, release_interface(handle, id[1].bInterfaceNumber)) |
| .WillOnce(Return(0)); |
| |
| msg = std::make_unique<MessageUSB>("1-1.1", &libusb, /*unit_test=*/true); |
| } |
| |
| libusb_device_handle* const handle = |
| reinterpret_cast<libusb_device_handle*>(10); |
| libusb_endpoint_descriptor ed[3]; |
| libusb_config_descriptor cd; |
| libusb_interface intf[2]; |
| libusb_interface_descriptor id[2]; |
| std::unique_ptr<MessageUSB> msg = nullptr; |
| }; |
| |
| TEST_F(MessageUSBMemberTest, CompleteOpen) |
| {} |
| |
| ACTION_P(CompareBulk, bulk) |
| { |
| EXPECT_EQ(arg3, bulk.size()); |
| EXPECT_EQ(0, memcmp(bulk.data(), arg2, bulk.size())); |
| *arg4 = bulk.size(); |
| } |
| |
| ACTION_P(PopulateBulk, out) |
| { |
| if (arg3 < static_cast<decltype(arg3)>(out.size())) |
| { |
| throw std::invalid_argument("populate"); |
| } |
| memcpy(arg2, out.data(), out.size()); |
| *arg4 = out.size(); |
| return 0; |
| } |
| |
| // Ensure truncated request transfers result in errors |
| TEST_F(MessageUSBMemberTest, SendPartialOut) |
| { |
| std::vector<uint8_t> buf = {1, 2, 3, 4, 5}; |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kWrite; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| memcpy(bulk.data() + sizeof(req), buf.data(), req.length); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1), |
| Return(0))); |
| EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure truncated response transfers result in errors |
| TEST_F(MessageUSBMemberTest, SendPartialIn) |
| { |
| std::vector<uint8_t> buf = {1, 2, 3, 4, 5}; |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kWrite; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| memcpy(bulk.data() + sizeof(req), buf.data(), req.length); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader)), _, 0)) |
| .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1), |
| Return(0))); |
| EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure non-zero response status codes result in errors |
| TEST_F(MessageUSBMemberTest, SendBadStatus) |
| { |
| std::vector<uint8_t> buf = {1, 2, 3, 4, 5}; |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kWrite; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| memcpy(bulk.data() + sizeof(req), buf.data(), req.length); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 1; |
| bulk.resize(sizeof(rsp)); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(rsp)), _, 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure responses that are too long get truncated |
| TEST_F(MessageUSBMemberTest, SendLongRsp) |
| { |
| std::vector<uint8_t> buf = {1}; |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kWrite; |
| req.offset = 3; |
| req.length = buf.size(); |
| std::vector<uint8_t> bulk(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| memcpy(bulk.data() + sizeof(req), buf.data(), req.length); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 0; |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(bulk.size()), _, 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| msg->send(buf.data(), buf.size(), 3); |
| } |
| |
| // Ensure we can successfully send a chunked message if the USB channel |
| // isn't wide enough to do it in a single transfer request |
| TEST_F(MessageUSBMemberTest, SendMulti) |
| { |
| std::vector<uint8_t> buf = {1, 2, 3}; |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kWrite; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| bulk[sizeof(req)] = buf[0]; |
| bulk[sizeof(req) + 1] = buf[1]; |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 0; |
| bulk.resize(sizeof(rsp)); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(rsp)), _, 0)) |
| .WillRepeatedly(PopulateBulk(bulk)); |
| |
| req.offset = 5; |
| req.length = 1; |
| bulk.resize(sizeof(req) + req.length); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| bulk[sizeof(req)] = buf[2]; |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| |
| msg->send(buf.data(), buf.size(), 3); |
| } |
| |
| // Ensure truncated request transfers result in errors |
| TEST_F(MessageUSBMemberTest, RecvPartialOut) |
| { |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kRead; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1), |
| Return(0))); |
| std::vector<uint8_t> recvd(3); |
| EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure truncated response transfers result in errors |
| TEST_F(MessageUSBMemberTest, RecvPartialIn) |
| { |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kRead; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| EXPECT_CALL(libusb, |
| bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader) + req.length), _, |
| 0)) |
| .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1), |
| Return(0))); |
| std::vector<uint8_t> recvd(3); |
| EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure extra response data is discarded |
| TEST_F(MessageUSBMemberTest, RecvLongRsp) |
| { |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kRead; |
| req.offset = 3; |
| req.length = 1; |
| std::vector<uint8_t> bulk(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 0; |
| bulk.resize(sizeof(rsp) + req.length + 1); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| bulk[sizeof(rsp)] = 3; |
| bulk[sizeof(rsp) + 1] = 4; |
| EXPECT_CALL(libusb, |
| bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader) + req.length), _, |
| 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| std::vector<uint8_t> recvd(1); |
| msg->recv(recvd.data(), recvd.size(), 3); |
| std::vector<uint8_t> expected{3}; |
| EXPECT_EQ(expected, recvd); |
| } |
| |
| // Ensure non-zero response status codes result in errors |
| TEST_F(MessageUSBMemberTest, RecvBadStatus) |
| { |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kRead; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 1; |
| bulk.resize(sizeof(rsp) + req.length); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| EXPECT_CALL(libusb, |
| bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader) + req.length), _, |
| 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| std::vector<uint8_t> recvd(3); |
| EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error); |
| } |
| |
| // Ensure we can successfully receive chunked messages if the USB channel |
| // isn't wide enough to do it in a single transfer request. |
| TEST_F(MessageUSBMemberTest, RecvMulti) |
| { |
| testing::InSequence seq; |
| |
| MessageUSB::ReqHeader req; |
| req.type = MessageUSB::ReqHeader::kRead; |
| req.offset = 3; |
| req.length = 2; |
| std::vector<uint8_t> bulk(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| MessageUSB::RspHeader rsp; |
| rsp.status = 0; |
| bulk.resize(sizeof(rsp) + req.length); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| bulk[sizeof(rsp)] = 2; |
| bulk[sizeof(rsp) + 1] = 1; |
| EXPECT_CALL(libusb, |
| bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader) + req.length), _, |
| 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| |
| req.offset = 5; |
| req.length = 1; |
| bulk.resize(sizeof(req)); |
| memcpy(bulk.data(), &req, sizeof(req)); |
| EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _, |
| bulk.size(), _, 0)) |
| .WillOnce(DoAll(CompareBulk(bulk), Return(0))); |
| bulk.resize(sizeof(rsp) + req.length); |
| memcpy(bulk.data(), &rsp, sizeof(rsp)); |
| bulk[sizeof(rsp)] = 3; |
| EXPECT_CALL(libusb, |
| bulk_transfer(handle, ed[2].bEndpointAddress, _, |
| Ge(sizeof(MessageUSB::RspHeader) + req.length), _, |
| 0)) |
| .WillOnce(PopulateBulk(bulk)); |
| |
| std::vector<uint8_t> recvd(3); |
| msg->recv(recvd.data(), recvd.size(), 3); |
| std::vector<uint8_t> expected{2, 1, 3}; |
| EXPECT_EQ(expected, recvd); |
| } |
| |
| } // namespace hoth |
| } // namespace google |