blob: 60d9e28f3d32d49c7a7d4d733037aeb5c439739c [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "hoth.hpp"
#include "google3/ec_commands.h"
#include "google3/host_commands.h"
#include "ec_util_interface.hpp"
#include "firmware_mtd_updater.hpp"
#include "message_util.hpp"
#include "payload_update_interface.hpp"
#include <stdplus/print.hpp>
#include <stdplus/raw.hpp>
#include <xyz/openbmc_project/Control/Hoth/error.hpp>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <future>
#include <span>
#include <string>
#include <utility>
#include <vector>
namespace google
{
namespace hoth
{
namespace
{
// Crypta command identifiers used to recognize host commands that must be
// filtered out of SendHostCommand.
// These values must stay consistent with the definitions under
// security/crypta/firmware/app/commands/. Ideally those definitions would be
// moved to a shared location so they would not need to be duplicated here.

// Major command groups.
constexpr uint8_t kCryptaProdidMajorCommand{0x05};
constexpr uint8_t kCryptaCimMajorCommand{0x01};

// Minor commands within the major groups above.
constexpr uint8_t kCryptaProdidBeginUnwrapType1BlobMinorCommand{0x00};
constexpr uint8_t kCryptaProdidLoadTokenMinorCommand{0x05};
constexpr uint8_t kCryptaCimLoadCbkMinorCommand{0x0A};
constexpr uint8_t kCryptaCimLoadCbkDataMinorCommand{0x0B};
constexpr uint8_t kCryptaCimLockCbkSlotMinorCommand{0x15};
constexpr uint8_t kCryptaCimIncrementCbkCounterCommand{0x17};

// Host command version expected in the header of a Crypta request.
constexpr uint8_t kCryptaHostCommandVersion{0x00};
} // namespace
using sdbusplus::error::xyz::openbmc_project::control::hoth::CommandFailure;
using sdbusplus::error::xyz::openbmc_project::control::hoth::ExpectedInfoNotFound;
using sdbusplus::error::xyz::openbmc_project::control::hoth::FirmwareFailure;
using sdbusplus::error::xyz::openbmc_project::control::hoth::InterfaceError;
using sdbusplus::error::xyz::openbmc_project::control::hoth::ResponseFailure;
using internal::CryptaHeader;
using internal::CBKSlotCmd;
using internal::EC_RES_ACCESS_DENIED;
using internal::FirmwareUpdater;
using internal::LoadTokenCmd;
using internal::PayloadUpdate;
using internal::ReqHeader;
using internal::SUPPORTED_STRUCT_VERSION;
// Polls a background task without blocking.
//
// Returns InProgress while the task is still running. Once it has finished,
// the result is consumed with get(), which leaves the future invalid. The
// function then returns Done on success, or logs the exception and returns
// Error if the task threw a std::exception.
Hoth::FirmwareUpdateStatus Hoth::getAsyncStatus(std::future<void>* asyncFuture)
{
    const bool stillRunning =
        asyncFuture->wait_for(std::chrono::seconds(0)) ==
        std::future_status::timeout;
    if (stillRunning)
    {
        return Hoth::FirmwareUpdateStatus::InProgress;
    }

    // wait_for() can only report timeout, ready or deferred, and these tasks
    // are always launched with std::launch::async, so the future is ready.
    try
    {
        asyncFuture->get();
    }
    catch (const std::exception& e)
    {
        stdplus::print(stderr, "The polled async thread had an exception: {}\n",
                       e.what());
        return Hoth::FirmwareUpdateStatus::Error;
    }
    return Hoth::FirmwareUpdateStatus::Done;
}
// Throws InterfaceError if any payload operation (initiate, send or verify)
// still holds a valid future.
//
// A future remains valid until its result is consumed by the matching status
// getter (see getAsyncStatus), so a started operation blocks new ones until
// then.
void Hoth::checkForOngoingPayload()
{
    const bool anyPending = futInitiatePayload.valid() ||
                            futSendPayload.valid() || futVerifyPayload.valid();
    if (!anyPending)
    {
        return;
    }

    // Report the first pending operation, in initiate/send/verify order.
    if (futInitiatePayload.valid())
    {
        stdplus::print(stderr,
                       "There is an initiate payload command in progress\n");
    }
    else if (futSendPayload.valid())
    {
        stdplus::print(stderr, "There is a send payload command in progress\n");
    }
    else
    {
        stdplus::print(stderr,
                       "There is a verify payload command in progress\n");
    }
    throw InterfaceError();
}
// Returns true if `command` starts with an EC request header that addresses
// the Hoth Crypta board-specific command. This requires a supported struct
// version and Crypta host command version kCryptaHostCommandVersion.
// A buffer too short to hold the header is never a Crypta command.
static bool isCryptaCommand(std::span<const uint8_t> command)
{
    const ReqHeader* header = nullptr;
    try
    {
        header = &stdplus::raw::refFrom<ReqHeader>(command);
    }
    catch (...)
    {
        // refFrom rejects buffers that cannot hold a ReqHeader.
        return false;
    }

    return header->struct_version == SUPPORTED_STRUCT_VERSION &&
           header->command ==
               EC_CMD_BOARD_SPECIFIC_BASE + EC_PRV_CMD_HOTH_CRYPTA &&
           header->command_version == kCryptaHostCommandVersion;
}
// Returns true for the Crypta (major, minor) pairs whose first parameter is a
// slot id that must be inspected before forwarding:
//   - PRODID / BeginUnwrapType1Blob
//   - CIM / LoadCbk, LoadCbkData, LockCbkSlot, IncrementCbkCounter
static bool isSensitiveSlotSpecificCommand(const CryptaHeader& hdr)
{
    const bool isProdidUnwrap =
        hdr.major == kCryptaProdidMajorCommand &&
        hdr.minor == kCryptaProdidBeginUnwrapType1BlobMinorCommand;
    if (isProdidUnwrap)
    {
        return true;
    }
    if (hdr.major != kCryptaCimMajorCommand)
    {
        return false;
    }
    switch (hdr.minor)
    {
        case kCryptaCimLoadCbkMinorCommand:
        case kCryptaCimLoadCbkDataMinorCommand:
        case kCryptaCimLockCbkSlotMinorCommand:
        case kCryptaCimIncrementCbkCounterCommand:
            return true;
        default:
            return false;
    }
}
// Returns true if `command` is large enough to hold a LoadTokenCmd and its
// Crypta header names the PRODID / LoadToken command.
static bool isLoadTokensCommand(std::span<const uint8_t> command)
{
    const LoadTokenCmd* cmd = nullptr;
    try
    {
        cmd = &stdplus::raw::refFrom<LoadTokenCmd>(command);
    }
    catch (...)
    {
        // Too short to be a load token command.
        return false;
    }

    const auto& hdr = cmd->cryptaHdr;
    return hdr.major == kCryptaProdidMajorCommand &&
           hdr.minor == kCryptaProdidLoadTokenMinorCommand;
}
// Returns true if `command` is a slot-specific Crypta command (see
// isSensitiveSlotSpecificCommand) that targets slot id 2.
//
// All such commands carry the slot id as their first parameter. This function
// only identifies well-formed messages: a message with no parameters, or with
// a first parameter that is not 4 bytes long, is reported as "not slot 2".
// Per the original authors, Crypta drops such malformed messages itself.
static bool isCBKSlot2Command(std::span<const uint8_t> command)
{
    const CBKSlotCmd* cmd = nullptr;
    try
    {
        cmd = &stdplus::raw::refFrom<CBKSlotCmd>(command);
    }
    catch (...)
    {
        // Too short to be a CBK slot command.
        return false;
    }

    if (!isSensitiveSlotSpecificCommand(cmd->cryptaHdr))
    {
        return false;
    }

    const bool hasFourByteSlotParam =
        cmd->cryptaHdr.count != 0 && cmd->cryptaParam.size == 4;
    return hasFourByteSlotParam && cmd->slotId == 2;
}
// Returns true for Crypta commands the host may not issue through
// SendHostCommand: PRODID LoadToken, and slot-specific commands that target
// slot 2.
static bool isHostForbiddenCommand(std::span<const uint8_t> command)
{
    return isCryptaCommand(command) &&
           (isLoadTokensCommand(command) || isCBKSlot2Command(command));
}
// Forwards `command` to Hoth unless it is on the forbidden list, in which
// case an EC_RES_ACCESS_DENIED error response is returned instead.
std::vector<uint8_t> Hoth::sendHostCommand(const std::vector<uint8_t> command)
{
    if (!isHostForbiddenCommand(command))
    {
        return hostCmd->sendCommand(command);
    }
    stdplus::print(stderr, "The Command is forbidden in SendHostCommand!\n");
    return generateErrorResponse(EC_RES_ACCESS_DENIED);
}
std::vector<uint8_t>
Hoth::sendTrustedHostCommand(const std::vector<uint8_t> command)
{
return hostCmd->sendCommand(command);
}
uint64_t Hoth::sendHostCommandAsync(const std::vector<uint8_t> command)
{
return hostCmd->sendCommandAsync(command);
}
std::vector<uint8_t> Hoth::getHostCommandResponse(uint64_t callToken)
{
return hostCmd->getResponse(callToken);
}
// Throws FirmwareFailure if the SPI write task behind futSpiWrite is still
// running. A finished write, even one whose result has not been polled yet,
// does not block a new one.
void Hoth::checkForOngoingSpiWrite()
{
    const bool busy = futSpiWrite.valid() &&
                      futSpiWrite.wait_for(std::chrono::seconds(0)) ==
                          std::future_status::timeout;
    if (!busy)
    {
        return;
    }
    stdplus::print(stderr, "There is already an SPI write in progress\n");
    throw FirmwareFailure();
}
// Starts FirmwareUpdater::update on a background thread with `firmwareData`.
// Progress is reported by getFirmwareUpdateStatus.
// Throws FirmwareFailure if an SPI write is already running.
void Hoth::updateFirmware(std::vector<uint8_t> firmwareData)
{
    checkForOngoingSpiWrite();
    // Move the image into the task instead of copying it.
    auto task = std::async(std::launch::async, &FirmwareUpdater::update,
                           firmwareUpdater, std::move(firmwareData));
    futSpiWrite = std::move(task);
}
// Starts FirmwareUpdater::spiWrite(address, data) on a background thread.
// Progress is reported by getFirmwareUpdateStatus, which shares futSpiWrite
// with updateFirmware.
// Throws FirmwareFailure if an SPI write is already running.
void Hoth::spiWrite(uint32_t address, std::vector<uint8_t> data)
{
    checkForOngoingSpiWrite();
    auto task = std::async(std::launch::async, &FirmwareUpdater::spiWrite,
                           firmwareUpdater, address, std::move(data));
    futSpiWrite = std::move(task);
}
// Reports the status of the background task stored in futSpiWrite. That task
// was started by either updateFirmware() or spiWrite().
//
// While futSpiWrite is valid, the task is polled via getAsyncStatus. Once the
// task has finished, getAsyncStatus consumes its result and the future
// becomes invalid. If no valid future exists (no write was ever started, or
// the result was already consumed), the last recorded status is returned.
Hoth::FirmwareUpdateStatus Hoth::getFirmwareUpdateStatus()
{
    if (futSpiWrite.valid())
    {
        lastFirmwareUpdateStatus = getAsyncStatus(&futSpiWrite);
    }
    return lastFirmwareUpdateStatus;
}
// Starts PayloadUpdate::initiate on a background thread. Progress is reported
// by getInitiatePayloadStatus.
// Throws InterfaceError if another payload operation is still pending.
void Hoth::initiatePayload()
{
    checkForOngoingPayload();
    auto task = std::async(std::launch::async, &PayloadUpdate::initiate,
                           payloadUpdate);
    futInitiatePayload = std::move(task);
}
// Reports the status of the task started by initiatePayload(). While its
// future is valid, the task is polled. Otherwise the last recorded status is
// returned: either initialization was never started, or its result has
// already been consumed.
Hoth::FirmwareUpdateStatus Hoth::getInitiatePayloadStatus()
{
    if (futInitiatePayload.valid())
    {
        lastInitiatePayloadStatus = getAsyncStatus(&futInitiatePayload);
    }
    return lastInitiatePayloadStatus;
}
// Starts PayloadUpdate::send(imagePath) on a background thread. Progress is
// reported by getSendPayloadStatus.
// Throws InterfaceError if another payload operation is pending, or
// FirmwareFailure if `imagePath` is empty.
void Hoth::sendPayload(std::string imagePath)
{
    checkForOngoingPayload();
    if (!imagePath.empty())
    {
        futSendPayload = std::async(std::launch::async, &PayloadUpdate::send,
                                    payloadUpdate, std::move(imagePath));
        return;
    }
    stdplus::print(stderr, "No image path was specified\n");
    throw FirmwareFailure();
}
// Reports the status of the task stored in futSendPayload. That task was
// started by sendPayload() or eraseAndSendStaticWPPayload(). While the future
// is valid, the task is polled. Otherwise the last recorded status is
// returned.
Hoth::FirmwareUpdateStatus Hoth::getSendPayloadStatus()
{
    if (futSendPayload.valid())
    {
        lastSendPayloadStatus = getAsyncStatus(&futSendPayload);
    }
    return lastSendPayloadStatus;
}
// Synchronously erases `size` bytes of the payload starting at `offset`.
// Both values must be multiples of the 4096-byte sector size.
// Throws InterfaceError if a payload operation is pending, or CommandFailure
// on misaligned arguments.
void Hoth::erasePayload(uint32_t offset, uint32_t size)
{
    constexpr uint32_t kSectorSize = 4096;
    checkForOngoingPayload();

    const bool offsetAligned = offset % kSectorSize == 0;
    if (!offsetAligned)
    {
        stdplus::print(stderr,
                       "Erase offset must be divisible by the sector size\n");
        throw CommandFailure();
    }

    const bool sizeAligned = size % kSectorSize == 0;
    if (!sizeAligned)
    {
        stdplus::print(stderr,
                       "Erase size must be divisible by the sector size\n");
        throw CommandFailure();
    }

    payloadUpdate->erase(offset, size);
}
// Starts PayloadUpdate::eraseAndSendStaticWP(imagePath) on a background
// thread. The task shares futSendPayload with sendPayload(), so progress is
// reported by getSendPayloadStatus.
// Throws InterfaceError if another payload operation is pending, or
// FirmwareFailure if `imagePath` is empty.
void Hoth::eraseAndSendStaticWPPayload(std::string imagePath)
{
    checkForOngoingPayload();
    if (!imagePath.empty())
    {
        futSendPayload = std::async(std::launch::async,
                                    &PayloadUpdate::eraseAndSendStaticWP,
                                    payloadUpdate, std::move(imagePath));
        return;
    }
    stdplus::print(stderr, "No image path was specified\n");
    throw FirmwareFailure();
}
// Starts PayloadUpdate::verify on a background thread. Progress is reported
// by getVerifyPayloadStatus.
// Throws InterfaceError if another payload operation is still pending.
void Hoth::verifyPayload()
{
    checkForOngoingPayload();
    auto task =
        std::async(std::launch::async, &PayloadUpdate::verify, payloadUpdate);
    futVerifyPayload = std::move(task);
}
// Reports the status of the task started by verifyPayload(). While its future
// is valid, the task is polled. Otherwise the last recorded status is
// returned: either verification was never started, or its result has already
// been consumed.
Hoth::FirmwareUpdateStatus Hoth::getVerifyPayloadStatus()
{
    if (futVerifyPayload.valid())
    {
        lastVerifyPayloadStatus = getAsyncStatus(&futVerifyPayload);
    }
    return lastVerifyPayloadStatus;
}
void Hoth::activatePayload(bool makePersistent)
{
checkForOngoingPayload();
// Activate the non-active (staging) partition with persistence level
// defined by makePersistent
payload_update_status response = payloadUpdate->getStatus();
payloadUpdate->activate(internal::Side(!response.active_half),
internal::Persistence(makePersistent));
}
void Hoth::deactivatePayload()
{
checkForOngoingPayload();
// Deactivate the non-active (staging) partition with persistence level
// defined by makePersistent by making current active_half as active.
payload_update_status response = payloadUpdate->getStatus();
payloadUpdate->activate(internal::Side(response.active_half),
internal::Persistence(true));
}
// Determines the payload (flash) size by probing which offsets accept a
// one-byte read. It assumes the readable offsets form a contiguous prefix
// [0, size).
//
// An exponential probe first finds an unreadable upper bound. A binary search
// then finds the first unreadable offset, which is the size.
//
// Throws InterfaceError if no byte can be read, or if every probed 32-bit
// offset is readable (the size would not fit in the uint32_t return value).
uint32_t Hoth::getPayloadSize()
{
    // Offsets passed to read() are 32-bit, so nothing at or beyond 2^32 can be
    // probed. Keeping the bounds in 64 bits lets the doubling stop at this
    // limit instead of wrapping a 32-bit bound around to 0. With a 32-bit
    // bound, a successful read at 0x80000000 made the probe loop forever
    // re-reading offset 0.
    constexpr uint64_t kOffsetLimit = uint64_t{1} << 32;
    std::array<uint8_t, 1> buf;

    // Returns true if a single byte can be read at `offset`. A ResponseFailure
    // is taken to mean the offset is out of range; any other exception
    // propagates to the caller.
    auto readable = [&](uint64_t offset) {
        try
        {
            payloadUpdate->read(static_cast<uint32_t>(offset), buf);
            return true;
        }
        catch (const ResponseFailure&)
        {
            return false;
        }
    };

    // Exponential probe: double `right` until a read there fails, or until
    // the 32-bit offset space is exhausted.
    uint64_t right = 1;
    while (right < kOffsetLimit && readable(right))
    {
        right <<= 1;
    }

    // Binary search for the first unreadable offset in [left, right].
    // `left` starts at the last offset the probe read successfully (or at 0
    // if the very first probe failed). `right` is known to be unreadable, or
    // is the 2^32 limit. The loop ends with `left` equal to the size.
    uint64_t left = right >> 1;
    while (left < right)
    {
        const uint64_t mid = left + (right - left) / 2;
        if (readable(mid))
        {
            left = mid + 1;
        }
        else
        {
            right = mid;
        }
    }

    if (left == 0)
    {
        stdplus::print(stderr, "Failed to get flash size from device\n");
        throw InterfaceError();
    }
    if (left >= kOffsetLimit)
    {
        // Every offset we could address was readable; the size cannot be
        // represented in the return type.
        stdplus::print(stderr, "Flash size does not fit in 32 bits\n");
        throw InterfaceError();
    }
    return static_cast<uint32_t>(left);
}
// Sends a payload-update Confirm request. The cookie is the current
// system_clock time, in milliseconds since the clock's epoch.
void Hoth::confirm()
{
    const auto sinceEpoch = std::chrono::system_clock::now().time_since_epoch();
    const auto cookie =
        std::chrono::duration_cast<std::chrono::milliseconds>(sinceEpoch);
    payloadUpdate->confirm(payload_update_confirm_option::Confirm, 0,
                           cookie.count());
}
uint32_t Hoth::getTotalBootTime()
{
ec_response_statistics statistic = ecUtil->getHothStatistics();
if (statistic.valid_words <
(offsetof(decltype(statistic), boot_timing_total) +
sizeof(statistic.boot_timing_total)) /
4)
{
stdplus::print(stderr,
"Timing information is not included in the response\n");
throw ExpectedInfoNotFound();
}
return statistic.boot_timing_total.end_us -
statistic.boot_timing_total.start_us;
}
uint32_t Hoth::getFirmwareUpdateTime()
{
ec_response_statistics statistic = ecUtil->getHothStatistics();
if (statistic.valid_words <
(offsetof(decltype(statistic), boot_timing_firmware_update) +
sizeof(statistic.boot_timing_firmware_update)) /
4)
{
stdplus::print(stderr,
"Timing information is not included in the response\n");
throw ExpectedInfoNotFound();
}
return statistic.boot_timing_firmware_update.end_us -
statistic.boot_timing_firmware_update.start_us;
}
uint32_t Hoth::getFirmwareMirroringTime()
{
ec_response_statistics statistic = ecUtil->getHothStatistics();
if (statistic.valid_words <
(offsetof(decltype(statistic), boot_timing_firmware_mirroring) +
sizeof(statistic.boot_timing_firmware_mirroring)) /
4)
{
stdplus::print(stderr,
"Timing information is not included in the response\n");
throw ExpectedInfoNotFound();
}
return statistic.boot_timing_firmware_mirroring.end_us -
statistic.boot_timing_firmware_mirroring.start_us;
}
uint32_t Hoth::getPayloadValidationTime()
{
ec_response_statistics statistic = ecUtil->getHothStatistics();
if (statistic.valid_words <
(offsetof(decltype(statistic), boot_timing_payload_validation) +
sizeof(statistic.boot_timing_payload_validation)) /
4)
{
stdplus::print(stderr,
"Timing information is not included in the response\n");
throw ExpectedInfoNotFound();
}
return statistic.boot_timing_payload_validation.end_us -
statistic.boot_timing_payload_validation.start_us;
}
void Hoth::collectHothLogs()
{
hostCmd->collectHothLogsAsync(true);
}
void Hoth::collectUartLogs()
{
hostCmd->collectUartLogsAsync();
}
void Hoth::stopUartLogs()
{
hostCmd->stopUartLogs();
}
} // namespace hoth
} // namespace google