// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "message_hoth_usb.hpp"

#include <stdplus/util/cexec.hpp>

#include <charconv>
#include <format>
#include <stdexcept>

// NOLINTNEXTLINE(cert-dcl58-cpp)
namespace google
{
namespace hoth
{
// Translate a libusb status code (LIBUSB_SUCCESS or a LIBUSB_ERROR_* value)
// into the matching errno value. LIBUSB_SUCCESS maps to 0; any code that is
// not listed in the table below maps to ENOSYS. This only converts codes — it
// never throws.
int toErrno(int error)
{
    struct ErrorMapping
    {
        int libusbCode;
        int errnoValue;
    };
    static constexpr ErrorMapping kErrorMappings[] = {
        {LIBUSB_SUCCESS, 0},
        {LIBUSB_ERROR_IO, EIO},
        {LIBUSB_ERROR_INVALID_PARAM, EINVAL},
        {LIBUSB_ERROR_ACCESS, EACCES},
        {LIBUSB_ERROR_NO_DEVICE, ENODEV},
        {LIBUSB_ERROR_NOT_FOUND, ENOENT},
        {LIBUSB_ERROR_BUSY, EBUSY},
        {LIBUSB_ERROR_TIMEOUT, ETIMEDOUT},
        {LIBUSB_ERROR_OVERFLOW, EOVERFLOW},
        {LIBUSB_ERROR_PIPE, EPIPE},
        {LIBUSB_ERROR_INTERRUPTED, EINTR},
        {LIBUSB_ERROR_NO_MEM, ENOMEM},
        {LIBUSB_ERROR_NOT_SUPPORTED, ENOTSUP},
    };

    for (const auto& mapping : kErrorMappings)
    {
        if (mapping.libusbCode == error)
        {
            return mapping.errnoValue;
        }
    }
    return ENOSYS;
}

inline auto makeError(int error, const char* msg)
{
    return std::system_error(toErrno(-error), std::generic_category(), msg);
}

template <typename... Args>
inline auto callCheck(const char* msg, Args&&... args)
{
    return stdplus::util::callCheckRet<makeError, Args...>(
        msg, std::forward<Args>(args)...);
}

// Exception raised by callThrow() when a libhoth_usb call reports a timeout.
// It lets callers, such as the receive loop in send(), tell a retryable
// timeout apart from other std::runtime_error failures.
class TimeOut : public std::runtime_error
{
  public:
    explicit TimeOut(const std::string& msg) : std::runtime_error(msg)
    {}
};

// Helper that invokes `f(args...)` and turns a failing status into an
// exception.
// Returns normally only when the status is LIBHOTH_OK. A timeout status throws
// TimeOut so that callers can retry. Every other status throws
// std::runtime_error with a description; unrecognised codes include the raw
// value. The switch handles both LIBHOTH_* and LIBUSB_* status values.
// `f` and `args` are perfectly forwarded, so arguments are neither copied nor
// turned into rvalues behind the caller's back.
template <class F, class... T>
void callThrow(F&& f, T&&... args)
{
    const auto ret = std::forward<F>(f)(std::forward<T>(args)...);
    switch (ret)
    {
        case LIBHOTH_OK:
            return;
        case LIBHOTH_ERR_UNKNOWN_VENDOR:
            throw std::runtime_error("hothusb error: unknown vendor");
        case LIBHOTH_ERR_INTERFACE_NOT_FOUND:
            throw std::runtime_error("hothusb error: interface not found");
        case LIBHOTH_ERR_MALLOC_FAILED:
            throw std::runtime_error("hothusb error: malloc failed");
        case LIBUSB_ERROR_TIMEOUT:
            // NOTE(review): libhoth may also report timeouts with its own
            // LIBHOTH_* code. Confirm which code the libhoth version in use
            // returns; a timeout that lands in `default` is not retried.
            throw TimeOut("hothusb error: timeout");
        case LIBHOTH_ERR_OUT_UNDERFLOW:
            throw std::runtime_error("hothusb error: out underflow");
        case LIBUSB_ERROR_OVERFLOW:
            throw std::runtime_error("hothusb error: in overflow");
        case LIBHOTH_ERR_UNSUPPORTED_VERSION:
            throw std::runtime_error("hothusb error: unsupported version");
        case LIBUSB_ERROR_BUSY:
        case LIBHOTH_ERR_INTERFACE_BUSY:
            throw std::runtime_error(
                "hothusb error: interface already claimed");
        default:
            throw std::runtime_error(
                std::format("hothusb error: unknown ({})", ret));
    }
}

// Scoped claim of a libhoth USB interface.
// The constructor claims `device` through `usb` and throws (via callThrow())
// if the claim fails. The destructor releases the claim. The class is neither
// copyable nor movable, so exactly one object is responsible for each claim.
class USBClaim
{
  public:
    explicit USBClaim(LibHothUsbIntf* usb, libhoth_device* device) :
        usb_(usb), device_(device)
    {
        // If the claim throws, construction fails and ~USBClaim() never runs,
        // so a claim that was never taken is never released.
        callThrow([this] { return usb_->claim(device_); });
    }

    USBClaim(const USBClaim&) = delete;
    USBClaim& operator=(const USBClaim&) = delete;
    USBClaim(USBClaim&&) = delete;
    USBClaim& operator=(USBClaim&&) = delete;

    ~USBClaim()
    {
        // Best effort: destructors must not throw, so the release status is
        // deliberately not checked.
        usb_->release(device_);
    }

  private:
    LibHothUsbIntf* usb_ = nullptr;
    libhoth_device* device_ = nullptr;
};

// Helper function to parse one integer token of a USB id.
// Parses `str` up to (but not including) the first `separator` as a decimal
// `Int`; if there is no separator, the whole string is parsed. The token and
// its separator are then removed from the front of `str`. For example:
//      std::string_view str = "123.456";
//      extract_separated_int<int>(str, '.')
//          -> return = (int)123, str = "456"
// Throws std::invalid_argument, leaving `str` unchanged, if:
//   - `str` is empty;
//   - the token is not a valid, in-range `Int`; or
//   - a separator is not followed by another token (e.g. "123.").
template <typename Int>
Int extract_separated_int(std::string_view& str, char separator)
{
    if (str.empty())
    {
        throw std::invalid_argument("Int str is empty");
    }

    const auto sep_pos = str.find(separator);
    const auto term = str.substr(0, sep_pos);
    const char* const end = term.data() + term.size();
    Int ret{};
    const auto res = std::from_chars(term.data(), end, ret);
    if (res.ec != std::errc() || res.ptr != end)
    {
        throw std::invalid_argument("Failed to parse int");
    }

    if (sep_pos == std::string_view::npos)
    {
        // Last token: it consumed the rest of the string.
        str.remove_prefix(str.size());
    }
    else
    {
        // Without this check, a trailing separator ("1-2.") would be silently
        // accepted as if it were absent.
        if (sep_pos + 1 == str.size())
        {
            throw std::invalid_argument("Missing int after separator");
        }
        // Drop the token and the separator that follows it.
        str.remove_prefix(sep_pos + 1);
    }
    return ret;
}

// Helper function to find a USB device from its USB id.
// `usb_id` has the form "<bus>-<port>[.<port>...]", for example "1-2.3". It is
// matched against get_bus_number() and get_port_numbers() of every device
// returned by get_device_list().
// The matching device is returned with an extra reference taken by
// ref_device(); the caller owns that reference and must drop it with
// unref_device().
// Throws:
//   - std::runtime_error if `libusb` is null or no device matches;
//   - std::invalid_argument if `usb_id` is malformed;
//   - std::system_error if a libusb query fails.
libusb_device* find_dev(LibusbIntf* libusb, libusb_context* ctx,
                        std::string_view usb_id)
{
    if (libusb == nullptr)
    {
        throw std::runtime_error("libusb is nullptr");
    }
    const auto bus_id = extract_separated_int<uint8_t>(usb_id, '-');
    std::vector<uint8_t> port_ids;
    do
    {
        port_ids.push_back(extract_separated_int<uint8_t>(usb_id, '.'));
    } while (!usb_id.empty());

    libusb_device** devices = nullptr;
    auto dev_count = callCheck("get_device_list", &LibusbIntf::get_device_list,
                               libusb, ctx, &devices);

    // Frees the device list on every exit path, including exceptions.
    // unref_devices = 1 drops the per-device references held by the list. A
    // returned match stays alive through the extra ref_device() below.
    struct DeviceListGuard
    {
        LibusbIntf* libusb;
        libusb_device** list;
        ~DeviceListGuard()
        {
            libusb->free_device_list(list, 1);
        }
    } guard{libusb, devices};

    for (decltype(dev_count) i = 0; i < dev_count; ++i)
    {
        libusb_device* dev = devices[i];
        if (libusb->get_bus_number(dev) != bus_id)
        {
            continue;
        }

        // As per the USB 3.0 specs, the current maximum limit for the depth
        // is 7.
        uint8_t port_numbers[7];
        const size_t depth = callCheck("get_port_numbers",
                                       &LibusbIntf::get_port_numbers, libusb,
                                       dev, port_numbers, sizeof(port_numbers));
        if (depth != port_ids.size())
        {
            continue;
        }

        bool match = true;
        for (size_t j = 0; j < depth; ++j)
        {
            if (port_numbers[j] != port_ids[j])
            {
                match = false;
                break;
            }
        }
        if (match)
        {
            // Keep the match alive after the guard frees the list.
            libusb->ref_device(dev);
            return dev;
        }
    }
    throw std::runtime_error("Failed to find usb device");
}

// Open the Hoth USB device identified by `usb_id` (e.g. "1-2.3").
// The constructor:
//   1. creates a libusb context through `libusb`;
//   2. locates the device with find_dev();
//   3. opens it through `libhoth_usb`, then releases the interface so it is
//      not held between transactions.
// If any step fails, the resources acquired by earlier steps are released and
// the original exception is rethrown.
MessageHothUSB::MessageHothUSB(std::string_view usb_id, LibusbIntf* libusb,
                               LibHothUsbIntf* libhoth_usb) :
    libusb_(libusb),
    ctx_([this]() {
        if (libusb_ == nullptr)
        {
            throw std::runtime_error("libusb is nullptr");
        }
        // Fail construction rather than continue with an uninitialized
        // context if libusb cannot be initialized. This assumes init() returns
        // a negative libusb code on failure, like libusb_init().
        libusb_context* context = nullptr;
        callCheck("init",
                  [this, &context] { return libusb_->init(&context); });
        return context;
    }()),
    dev_([this](std::string_view id) {
        try
        {
            return find_dev(libusb_, ctx_, id);
        }
        catch (...)
        {
            libusb_->exit(ctx_);
            // Rethrow with a bare `throw;` so the caller sees the original
            // exception type and message; `throw e;` would slice it to a
            // plain std::exception.
            throw;
        }
    }(usb_id)),
    libhoth_usb_(libhoth_usb), hoth_dev_([this]() {
        libhoth_device* dev = nullptr;
        libhoth_usb_device_init_options option{dev_, ctx_};
        try
        {
            if (libhoth_usb_ == nullptr)
            {
                throw std::runtime_error("libhoth_usb is nullptr");
            }
            callThrow([this, &option, &dev] {
                return libhoth_usb_->open(&option, &dev);
            });

            // Don't indefinitely hold the USB interface.
            libhoth_usb_->release(dev);
        }
        catch (...)
        {
            libusb_->unref_device(dev_);
            libusb_->exit(ctx_);

            // If we failed to open the device, then this class can't
            // meaningfully use it. Let the caller decide what to do.
            throw;
        }

        return dev;
    }())
{}

// Move constructor: start with every handle null, then take over `other`'s
// state through swap(). Assuming swap() exchanges all handle members, `other`
// is left holding the nulls, so its destructor releases nothing.
MessageHothUSB::MessageHothUSB(MessageHothUSB&& other) noexcept :
    libusb_(nullptr), ctx_(nullptr), dev_(nullptr), libhoth_usb_(nullptr),
    hoth_dev_(nullptr)
{
    swap(other);
}

// Move assignment using move-and-swap. `other` is moved into a local, the
// local's state is swapped into *this, and the local's destructor then
// releases whatever *this held before.
MessageHothUSB& MessageHothUSB::operator=(MessageHothUSB&& other) noexcept
{
    MessageHothUSB incoming(std::move(other));
    swap(incoming);
    return *this;
}

// Release resources in the reverse order the constructor acquired them:
// the libhoth device, then the libusb device reference, then the libusb
// context. Every handle is null-checked because it may be empty, for instance
// after a move.
MessageHothUSB::~MessageHothUSB()
{
    if (libhoth_usb_ != nullptr && hoth_dev_ != nullptr)
    {
        libhoth_usb_->close(hoth_dev_);
    }
    if (libusb_ == nullptr)
    {
        return;
    }
    if (dev_ != nullptr)
    {
        libusb_->unref_device(dev_);
    }
    if (ctx_ != nullptr)
    {
        libusb_->exit(ctx_);
    }
}

void MessageHothUSB::send(const uint8_t* buf, size_t size,
                          [[maybe_unused]] size_t seek)
{
    USBClaim usb_claim(libhoth_usb_, hoth_dev_);

    callThrow(
        [this, buf, size] { return libhoth_usb_->send(hoth_dev_, buf, size); });

    // libhoth_usb read/write is non addressible, we need to buffer the
    // response.
    buff_.resize(kMailboxSize);
    size_t read = 0;

    // Try infinitly on timeout.
    while (true)
    {
        try
        {
            callThrow([this, data = buff_.data(), read = &read] { return libhoth_usb_->recv(hoth_dev_, data, kMailboxSize, read, timeout); });
            break;
        }
        catch (const TimeOut& e)
        {
            // continue on timeout.
            continue;
        }
    }
    buff_.resize(read);
}

// Copy `size` bytes of the response buffered by the last send(), starting at
// offset `seek`, into `buf`.
// Throws std::runtime_error if the requested range extends past the buffered
// response.
void MessageHothUSB::recv(uint8_t* buf, size_t size, size_t seek)
{
    // Bound the copy by the response actually received, not by the mailbox
    // capacity: send() shrinks buff_ to the received length, so bytes past
    // buff_.size() are not part of the response. The check is written so that
    // `size + seek` cannot overflow.
    if (seek > buff_.size() || size > buff_.size() - seek)
    {
        throw std::runtime_error("Memory out of boundary");
    }
    if (size == 0)
    {
        // Nothing to copy; this also avoids memcpy on a possibly-null pointer
        // when no response is buffered.
        return;
    }

    memcpy(buf, buff_.data() + seek, size);
}

} // namespace hoth
} // namespace google
