blob: 02681a13ba2f1e7eb3a9a1e53c22a8d07a7f28fe [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "message_usb.hpp"
#include "usb_util.hpp"
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <format>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>
using namespace std::chrono_literals;
namespace google
{
namespace hoth
{
// Sets up a USB mailbox transport for the device identified by `usb_id`.
//
// The member initializers look up the device, the interface matching the
// mailbox class/subclass/protocol constants and its IN and OUT endpoints,
// clamp both per-transfer sizes to kMaxBulkTransferSize, then open the device
// and claim the interface.
//
// usb_id    - identifier passed to find_dev() to select the device.
// libusb    - libusb interface used to build the context.
// unit_test - when true, the post-initialization settle delay is skipped.
//
// Throws std::invalid_argument if either endpoint cannot carry its header
// plus at least one byte of payload.
MessageUSB::MessageUSB(std::string_view usb_id, LibusbIntf* libusb,
                       bool unit_test) :
    ctx(libusb),
    dev(find_dev(ctx, usb_id)),
    interface(
        find_intf(dev, kMailboxClass, kMailboxSubClasses, kMailboxProtocol)),
    in(find_ep(interface, LIBUSB_ENDPOINT_IN)),
    out(find_ep(interface, LIBUSB_ENDPOINT_OUT)),
    maxInSize(std::min(in.wMaxPacketSize, kMaxBulkTransferSize)),
    maxOutSize(std::min(out.wMaxPacketSize, kMaxBulkTransferSize)),
    handle(dev.open()),
    claim(handle.claim_interface(interface.bInterfaceNumber))
{
    // A request packet must hold the header and at least one data byte.
    if (maxOutSize <= sizeof(ReqHeader))
    {
        throw std::invalid_argument("Output packet size too small");
    }

    // Likewise a response packet must hold the header and one data byte.
    if (maxInSize <= sizeof(RspHeader))
    {
        throw std::invalid_argument("Input packet size too small");
    }

    if (unit_test)
    {
        return;
    }

    // Workaround: hoth produces bad responses if transfers start too quickly
    // after initialization, so give it a moment to settle.
    std::this_thread::sleep_for(std::chrono::seconds{1});
}
// Writes `size` bytes from `buf` to the mailbox, starting at offset `seek`.
//
// The data is split into chunks so that each request header plus its payload
// fits in maxOutSize bytes. For every chunk a kWrite request is sent on the
// OUT endpoint and a response header is then read from the IN endpoint.
//
// Throws std::runtime_error if a request is not transferred in full, if the
// response is shorter than a RspHeader, or if the response status is
// non-zero.
void MessageUSB::send(const uint8_t* buf, size_t size, size_t seek)
{
    // Payload bytes that fit next to the request header in one transfer.
    const auto payloadCap = maxOutSize - sizeof(ReqHeader);

    std::array<uint8_t, kMaxBulkTransferSize> packet;

    ReqHeader hdr;
    hdr.type = ReqHeader::kWrite;
    hdr.offset = seek;

    while (size != 0)
    {
        // Read the chunk length back from the header so every later use sees
        // exactly the value that goes on the wire.
        hdr.length = std::min(size, payloadCap);
        const size_t chunk = hdr.length;
        const size_t txLen = chunk + sizeof(hdr);

        // Assemble header + payload in the staging buffer.
        memcpy(packet.data(), &hdr, sizeof(hdr));
        memcpy(packet.data() + sizeof(hdr), buf, chunk);

        const auto written =
            handle.bulk_transfer(out, packet.data(), txLen, 0ms);
        if (written != txLen)
        {
            throw std::runtime_error("Failed to submit entire write request");
        }

        // The staging buffer is reused to receive the status response.
        const size_t rxLen =
            handle.bulk_transfer(in, packet.data(), packet.size(), 0ms);
        if (rxLen < sizeof(RspHeader))
        {
            throw std::runtime_error("Failed to receive write status");
        }

        RspHeader reply;
        memcpy(&reply, packet.data(), sizeof(reply));
        if (reply.status != 0)
        {
            throw std::runtime_error(
                std::format("Mailbox returned an error on write {:02x}",
                            static_cast<uint8_t>(reply.status)));
        }

        // Advance past the chunk that was just acknowledged.
        buf += chunk;
        size -= chunk;
        hdr.offset += hdr.length;
    }
}
// Reads `size` bytes from the mailbox, starting at offset `seek`, into `buf`.
//
// The read is performed in chunks: each iteration sends a kRead request
// header on the OUT endpoint asking for at most
// (maxInSize - sizeof(RspHeader)) bytes, then reads a response consisting of
// a RspHeader followed by payload from the IN endpoint. The loop advances by
// however many payload bytes were actually returned, so short responses are
// followed up with another request for the remainder.
//
// buf  - destination; must have room for `size` bytes.
// size - number of bytes to read.
// seek - starting offset placed in the first request header.
//
// Throws std::runtime_error if the request is not transferred in full, if
// the response is shorter than a RspHeader, if the response status is
// non-zero, or if a successful response carries no payload.
//
// NOTE(review): the 0ms timeout is passed straight to bulk_transfer();
// presumably it means "no timeout" as in libusb - confirm in usb_util.
void MessageUSB::recv(uint8_t* buf, size_t size, size_t seek)
{
    std::array<uint8_t, kMaxBulkTransferSize> data;
    ReqHeader req;
    req.type = ReqHeader::kRead;
    req.offset = seek;
    while (size > 0)
    {
        req.length = std::min(size, maxInSize - sizeof(RspHeader));
        memcpy(data.data(), &req, sizeof(req));
        if (handle.bulk_transfer(out, data.data(), sizeof(req), 0ms) !=
            sizeof(req))
        {
            throw std::runtime_error("Failed to send entire read request");
        }

        // The request buffer is reused to hold the response.
        size_t recvd = handle.bulk_transfer(in, data.data(), data.size(), 0ms);
        if (recvd < sizeof(RspHeader))
        {
            throw std::runtime_error("Failed to receive header");
        }

        RspHeader rsp;
        memcpy(&rsp, data.data(), sizeof(rsp));
        if (rsp.status != 0)
        {
            throw std::runtime_error(
                std::format("Mailbox returned an error on read {:02x}",
                            static_cast<uint8_t>(rsp.status)));
        }

        // Never copy more than the caller asked for, even if the response
        // carries extra bytes.
        recvd = std::min(size, recvd - sizeof(rsp));

        // A successful response with no payload makes no progress; without
        // this check the loop would re-issue the same request forever.
        if (recvd == 0)
        {
            throw std::runtime_error("Mailbox returned no data on read");
        }

        memcpy(buf, data.data() + sizeof(rsp), recvd);
        buf += recvd;
        size -= recvd;
        req.offset += recvd;
    }
}
} // namespace hoth
} // namespace google