// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "libusb_mock.hpp"
#include "message_usb.hpp"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace google
{
namespace hoth
{

using ::testing::_;
using ::testing::AnyNumber;
using ::testing::DoAll;
using ::testing::Ge;
using ::testing::Return;
using ::testing::SetArgPointee;

// gMock action for the mocked get_port_numbers(dev, buf, len) call: copies
// every entry of `nums` into the caller's buffer (arg1) and returns how many
// entries were copied. Returns LIBUSB_ERROR_OVERFLOW instead when the caller's
// buffer length (arg2) is smaller than nums.size().
ACTION_P(CopyNumbers, nums)
{
    const auto count = static_cast<decltype(arg2)>(nums.size());
    if (arg2 < count)
    {
        return LIBUSB_ERROR_OVERFLOW;
    }
    // The early return above guarantees the whole of `nums` fits, so there is
    // no need to clamp. std::copy is well-defined for an empty `nums`, unlike
    // memcpy with a possibly-null source pointer.
    std::copy(nums.begin(), nums.end(), arg1);
    return count;
}

class MessageUSBTest : public ::testing::Test
{
  protected:
    // Installs permissive expectations for the libusb setup calls: init()
    // hands back the fake `ctx`, and get_device_list() returns the
    // nullptr-terminated `devices` array built up by add_device().
    MessageUSBTest()
    {
        EXPECT_CALL(libusb, init(_))
            .WillRepeatedly(DoAll(SetArgPointee<0>(ctx), Return(0)));
        EXPECT_CALL(libusb, exit(ctx)).Times(AnyNumber());
        EXPECT_CALL(libusb, get_device_list(ctx, _))
            .WillRepeatedly([&](libusb_context*, libusb_device*** list) {
                *list = devices.data();
                return 0;
            });
        EXPECT_CALL(libusb, free_device_list(_, 0)).Times(AnyNumber());
    }

    // Registers a fake device with the mock.
    //   devn:  value used as the fake libusb_device pointer.
    //   bus:   value reported by get_bus_number() for this device.
    //   ports: port path reported by get_port_numbers() for this device.
    // The device replaces the trailing nullptr in `devices`, and a new
    // nullptr is appended so the list stays terminated.
    void add_device(int devn, int bus, const std::vector<uint8_t>& ports)
    {
        // Widen through std::uintptr_t first: it is the integer type that is
        // guaranteed to round-trip through a pointer, and converting a
        // narrower int directly can trigger -Wint-to-pointer-cast.
        const auto addr = static_cast<std::uintptr_t>(devn);
        // NOLINTNEXTLINE(performance-no-int-to-ptr)
        libusb_device* dev = reinterpret_cast<libusb_device*>(addr);
        EXPECT_CALL(libusb, get_bus_number(dev)).WillRepeatedly(Return(bus));
        EXPECT_CALL(libusb, get_port_numbers(dev, _, _))
            .WillRepeatedly(CopyNumbers(ports));
        EXPECT_CALL(libusb, unref_device(dev)).Times(AnyNumber());
        devices[devices.size() - 1] = dev;
        devices.push_back(nullptr);
    }

    // nullptr-terminated device list handed out by get_device_list().
    std::vector<libusb_device*> devices = {nullptr};
    ::testing::StrictMock<LibusbMock> libusb;
    // Sentinel value standing in for a real libusb_context.
    libusb_context* const ctx = reinterpret_cast<libusb_context*>(1);
};

TEST_F(MessageUSBTest, ParseDeviceString)
{
    // This test tells invalid bus+port strings apart from valid ones by the
    // exception type: MessageUSB throws std::invalid_argument for names it
    // cannot parse.
    for (const char* malformed : {"", "1", "a-1", "1-a", "1-.1", "1-1.3-"})
    {
        SCOPED_TRACE(malformed);
        EXPECT_THROW(MessageUSB(malformed, &libusb), std::invalid_argument);
    }

    // Well-formed names get past parsing and fail later with
    // std::runtime_error (no device has been registered in this test).
    for (const char* wellFormed : {"1-3", "1-1.3"})
    {
        SCOPED_TRACE(wellFormed);
        EXPECT_THROW(MessageUSB(wellFormed, &libusb), std::runtime_error);
    }
}

TEST_F(MessageUSBTest, MatchDevice)
{
    add_device(2, 1, {});
    add_device(3, 1, {1, 2});
    add_device(4, 1, {2, 1});

    // Depends on the mock never being asked for the active config_descriptor
    // when the device could not be found. Only once it finds a valid device
    // during lookup will it check the config_descriptor.
    EXPECT_THROW(MessageUSB("1-3", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1.1", &libusb), std::runtime_error);
    EXPECT_THROW(MessageUSB("1-1.2.1", &libusb), std::runtime_error);

    // "1-1.2" matches device 3 (bus 1, ports {1, 2}), so its config
    // descriptor is fetched. With zero interfaces there is presumably no
    // mailbox to use, and construction still fails. Value-initialized so
    // every field other than bNumInterfaces is zero/null, not indeterminate.
    libusb_config_descriptor cd{};
    cd.bNumInterfaces = 0;
    EXPECT_CALL(libusb, get_active_config_descriptor(
                            reinterpret_cast<libusb_device*>(3), _))
        .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0)));
    EXPECT_THROW(MessageUSB("1-1.2", &libusb), std::runtime_error);
}

// Fixture that registers device 2 on bus 1 with port path {1, 1}. It exposes
// a fake config descriptor containing one hoth mailbox interface (id[1],
// interface number 5) with a BULK OUT and a BULK IN endpoint, and constructs
// `msg` on "1-1.1" with the open/claim/release/close expectations in place.
class MessageUSBMemberTest : public MessageUSBTest
{
  protected:
    MessageUSBMemberTest()
    {
        // ed[0] is intentionally a non-BULK transfer endpoint
        // to ensure we filter correctly.
        ed[0].bmAttributes = LIBUSB_TRANSFER_TYPE_CONTROL;
        // Each BULK endpoint's max packet size leaves room for exactly two
        // payload bytes after its header; the *Multi tests rely on this to
        // force chunked transfers.
        ed[1].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK;
        ed[1].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_OUT;
        ed[1].wMaxPacketSize = sizeof(MessageUSB::ReqHeader) + 2;
        ed[2].bmAttributes = LIBUSB_TRANSFER_TYPE_BULK;
        ed[2].bEndpointAddress = 0x03 | LIBUSB_ENDPOINT_IN;
        ed[2].wMaxPacketSize = sizeof(MessageUSB::RspHeader) + 2;

        // id[0] is intentionally not a hoth mailbox (its subclass differs)
        // to ensure we filter for the correct interface.
        id[0].bInterfaceClass = MessageUSB::kMailboxClass;
        id[0].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1] - 1;
        id[0].bInterfaceProtocol = MessageUSB::kMailboxProtocol;
        id[1].endpoint = ed;
        id[1].bNumEndpoints = 3;
        id[1].bInterfaceClass = MessageUSB::kMailboxClass;
        id[1].bInterfaceSubClass = MessageUSB::kMailboxSubClasses[1];
        id[1].bInterfaceProtocol = MessageUSB::kMailboxProtocol;
        id[1].bInterfaceNumber = 5;

        // intf[0] is intentionally empty to ensure we are looping over these
        // rather than only inspecting the first interface.
        intf[0].num_altsetting = 0;
        intf[1].altsetting = id;
        intf[1].num_altsetting = 2;
        cd.interface = intf;
        cd.bNumInterfaces = 2;

        add_device(2, 1, {1, 1});
        EXPECT_CALL(libusb, get_active_config_descriptor(devices[0], _))
            .WillOnce(DoAll(SetArgPointee<1>(&cd), Return(0)));
        EXPECT_CALL(libusb, open(devices[0], _))
            .WillOnce(DoAll(SetArgPointee<1>(handle), Return(0)));
        EXPECT_CALL(libusb, close(handle)).Times(1);
        EXPECT_CALL(libusb, claim_interface(handle, id[1].bInterfaceNumber))
            .WillOnce(Return(0));
        EXPECT_CALL(libusb, release_interface(handle, id[1].bInterfaceNumber))
            .WillOnce(Return(0));

        msg = std::make_unique<MessageUSB>("1-1.1", &libusb, /*unit_test=*/true);
    }

    libusb_device_handle* const handle =
        reinterpret_cast<libusb_device_handle*>(10);
    // Value-initialized so every descriptor field the constructor does not
    // set explicitly is zero/null rather than indeterminate. MessageUSB may
    // read such fields (e.g. intf[0].altsetting, id[0].bNumEndpoints) while
    // scanning for the mailbox interface.
    libusb_endpoint_descriptor ed[3] = {};
    libusb_config_descriptor cd = {};
    libusb_interface intf[2] = {};
    libusb_interface_descriptor id[2] = {};
    std::unique_ptr<MessageUSB> msg = nullptr;
};

// Exercises only the fixture. Constructing MessageUSB on "1-1.1" must satisfy
// the get_active_config_descriptor/open/claim_interface expectations set up
// in MessageUSBMemberTest's constructor. release_interface and close are also
// expected exactly once, presumably when `msg` is destroyed at teardown. Any
// other libusb call fails the test through the StrictMock.
TEST_F(MessageUSBMemberTest, CompleteOpen)
{}

// gMock action for the mocked bulk_transfer(handle, endpoint, data, length,
// transferred, timeout) call. It checks that the submitted buffer (arg2, of
// length arg3) matches `bulk` byte-for-byte, then reports the full size
// through `transferred` (arg4). It returns nothing, so it is only meant as a
// non-final step inside DoAll(). A later SetArgPointee<4> in the same DoAll()
// may override the reported size.
ACTION_P(CompareBulk, bulk)
{
    const auto expected_len = static_cast<decltype(arg3)>(bulk.size());
    // Compare in the length's own (signed) type to avoid a signed/unsigned
    // mismatch inside EXPECT_EQ.
    EXPECT_EQ(arg3, expected_len);
    // Only compare contents when the caller's buffer is at least as long as
    // `bulk`; otherwise memcmp would read past the end of arg2. A size
    // mismatch has already been reported above.
    if (arg3 >= expected_len)
    {
        EXPECT_EQ(0, memcmp(bulk.data(), arg2, bulk.size()));
    }
    *arg4 = bulk.size();
}

// gMock action for the mocked bulk_transfer(handle, endpoint, data, length,
// transferred, timeout) call on the IN endpoint. It fills the caller's buffer
// (arg2) with `out`, reports out.size() through `transferred` (arg4), and
// returns 0. It throws std::invalid_argument("populate") when the caller's
// buffer length (arg3) is too small to hold `out`.
ACTION_P(PopulateBulk, out)
{
    const auto needed = static_cast<decltype(arg3)>(out.size());
    if (needed > arg3)
    {
        throw std::invalid_argument("populate");
    }
    std::copy(out.begin(), out.end(), arg2);
    *arg4 = out.size();
    return 0;
}

// Ensure truncated request transfers result in errors: the OUT transfer
// reports one byte fewer than was submitted, and send() must throw.
TEST_F(MessageUSBMemberTest, SendPartialOut)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Value-initialized so header fields not set below (and any padding) are
    // zero rather than indeterminate. CompareBulk compares these bytes
    // verbatim against what MessageUSB sends.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    // CompareBulk reports the full length; SetArgPointee then overrides it
    // with a short count to simulate the truncated transfer.
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1),
                        Return(0)));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}

// Ensure truncated response transfers result in errors: the IN transfer
// returns fewer bytes than a full response header, and send() must throw.
TEST_F(MessageUSBMemberTest, SendPartialIn)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Value-initialized so header bytes compared by CompareBulk are
    // deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(MessageUSB::RspHeader)), _, 0))
        .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1),
                        Return(0)));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}

// Ensure non-zero response status codes result in errors.
TEST_F(MessageUSBMemberTest, SendBadStatus)
{
    std::vector<uint8_t> buf = {1, 2, 3, 4, 5};
    // Headers are value-initialized so any fields not set below are zero
    // rather than indeterminate: the request bytes are compared verbatim and
    // the response bytes are fed back into MessageUSB.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 1;
    bulk.resize(sizeof(rsp));
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(rsp)), _, 0))
        .WillOnce(PopulateBulk(bulk));
    EXPECT_THROW(msg->send(buf.data(), buf.size(), 3), std::runtime_error);
}

// Ensure a response longer than its header is accepted by send() rather than
// treated as an error.
TEST_F(MessageUSBMemberTest, SendLongRsp)
{
    // The response header below is written over the start of `bulk`, which
    // is only sizeof(ReqHeader) + 1 bytes long. Guarantee at compile time that
    // it fits, so the memcpy cannot overflow the buffer.
    static_assert(sizeof(MessageUSB::RspHeader) <=
                      sizeof(MessageUSB::ReqHeader) + 1,
                  "RspHeader must fit in the reused request buffer");

    std::vector<uint8_t> buf = {1};
    // Value-initialized so header bytes that are compared verbatim or fed
    // back into MessageUSB are deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = buf.size();
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    memcpy(bulk.data() + sizeof(req), buf.data(), req.length);
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    // Only the leading header bytes are overwritten; the remaining request
    // bytes stay in place as extra trailing response data.
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(bulk.size()), _, 0))
        .WillOnce(PopulateBulk(bulk));
    msg->send(buf.data(), buf.size(), 3);
}

// Ensure we can successfully send a chunked message if the USB channel
// isn't wide enough to do it in a single transfer request. The OUT endpoint
// carries two payload bytes per transfer, so three bytes take two requests.
TEST_F(MessageUSBMemberTest, SendMulti)
{
    std::vector<uint8_t> buf = {1, 2, 3};
    // Value-initialized so header bytes compared verbatim (and response bytes
    // fed back into MessageUSB) are deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kWrite;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    bulk[sizeof(req)] = buf[0];
    bulk[sizeof(req) + 1] = buf[1];
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    // One successful response serves both chunks.
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp));
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[2].bEndpointAddress, _,
                                      Ge(sizeof(rsp)), _, 0))
        .WillRepeatedly(PopulateBulk(bulk));

    // Second chunk: the remaining byte at the next offset.
    req.offset = 5;
    req.length = 1;
    bulk.resize(sizeof(req) + req.length);
    memcpy(bulk.data(), &req, sizeof(req));
    bulk[sizeof(req)] = buf[2];
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));

    msg->send(buf.data(), buf.size(), 3);
}

// Ensure truncated request transfers result in errors: the OUT transfer of
// the read request reports one byte fewer than was submitted, and recv() must
// throw.
TEST_F(MessageUSBMemberTest, RecvPartialOut)
{
    // Value-initialized so header bytes compared by CompareBulk are
    // deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), SetArgPointee<4>(bulk.size() - 1),
                        Return(0)));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}

// Ensure truncated response transfers result in errors: the IN transfer
// returns fewer bytes than a full response header, and recv() must throw.
TEST_F(MessageUSBMemberTest, RecvPartialIn)
{
    // Value-initialized so header bytes compared by CompareBulk are
    // deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(DoAll(SetArgPointee<4>(sizeof(MessageUSB::RspHeader) - 1),
                        Return(0)));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}

// Ensure extra response data is discarded: the device returns one more data
// byte than requested, and only the requested byte lands in `recvd`.
TEST_F(MessageUSBMemberTest, RecvLongRsp)
{
    // Headers are value-initialized so bytes compared verbatim or fed back
    // into MessageUSB are deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 1;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp) + req.length + 1);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 3;
    bulk[sizeof(rsp) + 1] = 4;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    std::vector<uint8_t> recvd(1);
    msg->recv(recvd.data(), recvd.size(), 3);
    std::vector<uint8_t> expected{3};
    EXPECT_EQ(expected, recvd);
}

// Ensure non-zero response status codes result in errors.
TEST_F(MessageUSBMemberTest, RecvBadStatus)
{
    // Headers are value-initialized so bytes compared verbatim or fed back
    // into MessageUSB are deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 1;
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));
    std::vector<uint8_t> recvd(3);
    EXPECT_THROW(msg->recv(recvd.data(), recvd.size(), 3), std::runtime_error);
}

// Ensure we can successfully receive chunked messages if the USB channel
// isn't wide enough to do it in a single transfer request. The IN endpoint
// carries two payload bytes per transfer, so three bytes take two
// request/response round trips, which must happen in order.
TEST_F(MessageUSBMemberTest, RecvMulti)
{
    testing::InSequence seq;

    // Headers are value-initialized so bytes compared verbatim or fed back
    // into MessageUSB are deterministic.
    MessageUSB::ReqHeader req{};
    req.type = MessageUSB::ReqHeader::kRead;
    req.offset = 3;
    req.length = 2;
    std::vector<uint8_t> bulk(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    MessageUSB::RspHeader rsp{};
    rsp.status = 0;
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 2;
    bulk[sizeof(rsp) + 1] = 1;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));

    // Second round trip: the remaining byte at the next offset.
    req.offset = 5;
    req.length = 1;
    bulk.resize(sizeof(req));
    memcpy(bulk.data(), &req, sizeof(req));
    EXPECT_CALL(libusb, bulk_transfer(handle, ed[1].bEndpointAddress, _,
                                      bulk.size(), _, 0))
        .WillOnce(DoAll(CompareBulk(bulk), Return(0)));
    bulk.resize(sizeof(rsp) + req.length);
    memcpy(bulk.data(), &rsp, sizeof(rsp));
    bulk[sizeof(rsp)] = 3;
    EXPECT_CALL(libusb,
                bulk_transfer(handle, ed[2].bEndpointAddress, _,
                              Ge(sizeof(MessageUSB::RspHeader) + req.length), _,
                              0))
        .WillOnce(PopulateBulk(bulk));

    std::vector<uint8_t> recvd(3);
    msg->recv(recvd.data(), recvd.size(), 3);
    std::vector<uint8_t> expected{2, 1, 3};
    EXPECT_EQ(expected, recvd);
}

} // namespace hoth
} // namespace google
