blob: 0c435c9eb914cb615219594b0060ee10742fc432 [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "usb_util.hpp"

#include <algorithm>
#include <charconv>
#include <span>
#include <stdexcept>
#include <system_error>
namespace google
{
namespace hoth
{
/// Parses one integer from the front of @p str and consumes it.
///
/// The field is the text up to (but not including) the first occurrence of
/// @p separator, or all of @p str when the separator does not occur. On
/// success the field and the separator that follows it (if any) are removed
/// from @p str, so repeated calls walk a list such as "2.30".
///
/// @tparam Int integer type to parse into; overflow is detected by
///         std::from_chars.
/// @param str input view; advanced past the consumed field on success and
///        left unmodified when an exception is thrown.
/// @param separator character that terminates the field.
/// @return the parsed integer.
/// @throws std::invalid_argument if @p str is empty, if the field is not a
///         complete base-10 integer representable by Int, or if the separator
///         is the last character of @p str (no field follows it).
template <typename Int>
Int extract_separated_int(std::string_view& str, char separator)
{
    if (str.empty())
    {
        throw std::invalid_argument("Int str is empty");
    }
    const auto sep_pos = str.find(separator);
    const auto term = str.substr(0, sep_pos);
    const char* const end = term.data() + term.size();
    Int ret{};
    const auto res = std::from_chars(term.data(), end, ret);
    // Reject partial parses such as "12x", as well as format/range errors.
    if (res.ec != std::errc{} || res.ptr != end)
    {
        throw std::invalid_argument("Failed to parse int");
    }
    if (sep_pos == std::string_view::npos)
    {
        // Last field: the whole remaining input was consumed.
        str.remove_prefix(term.size());
        return ret;
    }
    // A separator promises another field. Consuming a trailing one would
    // leave `str` empty and make e.g. "1-2." indistinguishable from "1-2".
    if (sep_pos + 1 == str.size())
    {
        throw std::invalid_argument("Trailing separator after int");
    }
    str.remove_prefix(sep_pos + 1);
    return ret;
}
/// Looks up the device addressed by @p usb_id among ctx.get_device_list().
///
/// @p usb_id has the form "<bus>-<port>[.<port>...]", e.g. "1-2.4": a bus
/// number, a '-', then one or more '.'-separated port numbers. A device
/// matches when its get_bus_number() equals the bus and its
/// get_port_numbers() equals the parsed port list.
///
/// @return the first matching device, moved out of the enumerated list.
/// @throws std::invalid_argument if @p usb_id cannot be parsed.
/// @throws std::runtime_error if no enumerated device matches.
libusb::Device find_dev(libusb::Context& ctx, std::string_view usb_id)
{
    const auto bus = extract_separated_int<uint8_t>(usb_id, '-');

    // At least one port is always parsed; keep going until the id is used up.
    std::vector<uint8_t> path;
    for (;;)
    {
        path.push_back(extract_separated_int<uint8_t>(usb_id, '.'));
        if (usb_id.empty())
        {
            break;
        }
    }

    for (auto& candidate : ctx.get_device_list())
    {
        if (candidate.get_bus_number() != bus)
        {
            continue;
        }
        if (candidate.get_port_numbers() == path)
        {
            // `candidate` refers into the enumerated list, so it must be
            // moved explicitly to hand ownership to the caller.
            return std::move(candidate);
        }
    }
    throw std::runtime_error("Failed to find usb device");
}
/// Searches the active configuration of @p dev for an interface alternate
/// setting whose class equals @p usbClass, whose subclass is any element of
/// @p usbSubclasses, and whose protocol equals @p usbProtocol.
///
/// Interfaces are scanned in descriptor order, and each interface's alternate
/// settings in order; the first match wins.
///
/// @return reference to the matching altsetting inside the descriptor
///         returned by dev.get_active_config_descriptor().
///         NOTE(review): its lifetime presumably follows the descriptor owned
///         by @p dev — confirm against libusb::Device.
/// @throws std::runtime_error if no alternate setting matches.
const libusb_interface_descriptor&
find_intf(libusb::Device& dev, uint8_t usbClass,
          std::span<const uint8_t> usbSubclasses, uint8_t usbProtocol)
{
    const auto& cfg = dev.get_active_config_descriptor();
    const auto subclass_wanted = [&](uint8_t sub) {
        return std::find(usbSubclasses.begin(), usbSubclasses.end(), sub) !=
               usbSubclasses.end();
    };

    for (uint8_t intf_idx = 0; intf_idx < cfg.bNumInterfaces; ++intf_idx)
    {
        const auto& intf = cfg.interface[intf_idx];
        for (int alt_idx = 0; alt_idx < intf.num_altsetting; ++alt_idx)
        {
            const auto& alt = intf.altsetting[alt_idx];
            if (alt.bInterfaceClass != usbClass ||
                alt.bInterfaceProtocol != usbProtocol)
            {
                continue;
            }
            if (subclass_wanted(alt.bInterfaceSubClass))
            {
                return alt;
            }
        }
    }
    throw std::runtime_error("Failed to find hoth mailbox interface");
}
/// Returns the first endpoint of @p interface that meets two conditions:
/// the transfer-type bits of bmAttributes equal LIBUSB_TRANSFER_TYPE_BULK,
/// and bEndpointAddress masked with libusb::kAddressDirectionBits equals
/// @p dir.
///
/// @param dir expected masked direction value. Presumably LIBUSB_ENDPOINT_IN
///        or LIBUSB_ENDPOINT_OUT; confirm at call sites.
/// @throws std::runtime_error if no such endpoint exists.
const libusb_endpoint_descriptor&
find_ep(const libusb_interface_descriptor& interface, uint8_t dir)
{
    for (unsigned idx = 0; idx != interface.bNumEndpoints; ++idx)
    {
        const auto& ep = interface.endpoint[idx];
        const bool is_bulk = (ep.bmAttributes & libusb::kTransferTypeBits) ==
                             LIBUSB_TRANSFER_TYPE_BULK;
        const bool dir_matches =
            (ep.bEndpointAddress & libusb::kAddressDirectionBits) == dir;
        if (is_bulk && dir_matches)
        {
            return ep;
        }
    }
    throw std::runtime_error("Failed to find hoth transfer endpoint");
}
} // namespace hoth
} // namespace google