blob: 256fe7316aa421ff5bceb462c3cd01d6078359cb [file] [log] [blame] [edit]
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tlv.h"
#include "types.h"
#include <endian.h>
#include <cstring>
#include <format>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>
namespace debug_token::tlv_decoder
{
// Explicit instantiation of template functions for supported value types
template uint8_t Item::getValue() const;
template uint16_t Item::getValue() const;
template uint32_t Item::getValue() const;
template uint64_t Item::getValue() const;
template std::vector<uint8_t> Item::getValue() const;
template std::vector<uint16_t> Item::getValue() const;
template std::vector<uint32_t> Item::getValue() const;
template std::vector<uint64_t> Item::getValue() const;
/**
 * @brief Parse one TLV item from the front of @p input.
 *
 * Copies the fixed-size ItemHeader out of the buffer, then copies the number
 * of payload bytes given by the header's little-endian size field. Any bytes
 * after this item's payload are ignored.
 *
 * @param input Buffer that starts with an item header.
 * @throws std::runtime_error if @p input is shorter than the header, or
 *         shorter than the header plus the declared payload length.
 */
Item::Item(const std::span<const uint8_t> input)
{
    constexpr size_t headerBytes = sizeof(ItemHeader);
    if (input.size() < headerBytes)
    {
        throw std::runtime_error(std::format(
            "TLV item input data is shorter than the header - size: {}",
            input.size()));
    }

    header = std::make_shared<ItemHeader>();
    std::memcpy(header.get(), input.data(), headerBytes);

    const auto payloadBytes = static_cast<size_t>(le16toh(header->size));
    const auto afterHeader = input.subspan(headerBytes);
    if (afterHeader.size() < payloadBytes)
    {
        throw std::runtime_error(
            std::format("TLV item input data is too short - type: {}, size: {}",
                        le16toh(header->type), le16toh(header->size)));
    }

    const auto payload = afterHeader.first(payloadBytes);
    data.assign(payload.begin(), payload.end());
}
/// @brief Return the item's type identifier, converted from the header's
///        little-endian encoding to host byte order.
uint16_t Item::getType() const
{
    const uint16_t hostOrderType = le16toh(header->type);
    return hostOrderType;
}
/// @brief Return the encoded size of this item: header bytes plus the
///        payload length declared in the header.
size_t Item::getTotalSize() const
{
    const size_t payloadBytes = getValueSize();
    return payloadBytes + sizeof(ItemHeader);
}
/// @brief Return the payload length declared in the item header (the
///        little-endian size field, converted to host order).
size_t Item::getValueSize() const
{
    const auto declaredLength = le16toh(header->size);
    return static_cast<size_t>(declaredLength);
}
const std::vector<uint8_t>& Item::getRawValue() const
{
return data;
}
/**
 * @brief Decode the item payload as an array of little-endian integers.
 *
 * The payload is split into consecutive sizeof(T::value_type)-byte chunks.
 * Each chunk is converted from little-endian to host byte order. Requesting
 * std::vector<uint8_t> returns a copy of the raw payload without any checks.
 *
 * @tparam T Vector type. Its value_type is assumed to be an unsigned integer
 *           (VectorType is declared elsewhere — confirm). Widths of 1, 2, 4
 *           or 8 bytes are supported.
 * @return The decoded elements, in payload order.
 * @throws std::runtime_error if the payload length is not a multiple of the
 *         element size.
 */
template <VectorType T>
T Item::getValue() const
{
    using Element = typename T::value_type;

    if constexpr (std::is_same_v<T, std::vector<uint8_t>>)
    {
        // Byte arrays need no decoding.
        return data;
    }
    else
    {
        static_assert(sizeof(Element) == 1 || sizeof(Element) == 2 ||
                          sizeof(Element) == 4 || sizeof(Element) == 8,
                      "TLV array elements must be 1, 2, 4 or 8 bytes wide");

        if (data.size() % sizeof(Element) != 0)
        {
            throw std::runtime_error(std::format(
                "TLV item vector data size is not a multiple of the type size"
                " - type: {}, expected: {}, actual: {}",
                typeid(Element).name(), sizeof(Element), data.size()));
        }

        T result;
        result.reserve(data.size() / sizeof(Element));
        for (size_t offset = 0; offset < data.size(); offset += sizeof(Element))
        {
            // Copy into a properly aligned local instead of reading the byte
            // buffer through a reinterpreted pointer.
            Element raw{};
            std::memcpy(&raw, data.data() + offset, sizeof(Element));

            // Select the conversion by width at compile time so only the
            // matching le*toh call is instantiated for this element type.
            if constexpr (sizeof(Element) == 1)
            {
                result.push_back(raw);
            }
            else if constexpr (sizeof(Element) == 2)
            {
                result.push_back(static_cast<Element>(le16toh(raw)));
            }
            else if constexpr (sizeof(Element) == 4)
            {
                result.push_back(static_cast<Element>(le32toh(raw)));
            }
            else
            {
                result.push_back(static_cast<Element>(le64toh(raw)));
            }
        }
        return result;
    }
}
/**
 * @brief Decode the item payload as one little-endian unsigned integer.
 *
 * The payload length must equal sizeof(T) exactly. The byte-order conversion
 * is chosen by the width of T rather than by exact type identity. Any 1/2/4/8
 * byte unsigned integral type is therefore converted correctly, including
 * types that are distinct from, but as wide as, the uint*_t aliases (e.g.
 * unsigned long long on LP64).
 *
 * @tparam T Unsigned integral result type.
 * @return The decoded value in host byte order.
 * @throws std::runtime_error if the payload length differs from sizeof(T).
 */
template <std::unsigned_integral T>
T Item::getValue() const
{
    static_assert(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 ||
                      sizeof(T) == 8,
                  "TLV scalar values must be 1, 2, 4 or 8 bytes wide");

    if (sizeof(T) != data.size())
    {
        throw std::runtime_error(
            std::format("TLV item data size does not match the type size"
                        " - type: {}, expected: {}, actual: {}",
                        typeid(T).name(), sizeof(T), data.size()));
    }

    if constexpr (sizeof(T) == 1)
    {
        // A single byte has no byte order.
        return static_cast<T>(data[0]);
    }
    else
    {
        // Copy into an aligned local; no heap allocation is needed.
        T raw{};
        std::memcpy(&raw, data.data(), sizeof(T));
        if constexpr (sizeof(T) == 2)
        {
            return static_cast<T>(le16toh(raw));
        }
        else if constexpr (sizeof(T) == 4)
        {
            return static_cast<T>(le32toh(raw));
        }
        else
        {
            return static_cast<T>(le64toh(raw));
        }
    }
}
/**
 * @brief Return a human-readable name for a TLV type identifier.
 *
 * Known identifiers from types::Common, types::GPU, types::NBU and
 * types::BMCIRoT map to fixed names. Device-specific names carry their family
 * as a prefix. Any other value is rendered as "UnknownType(0xNNNN)".
 *
 * @param type TLV type identifier in host byte order.
 * @return The type's display name.
 */
std::string Item::getTypeName(uint16_t type)
{
    struct NamedType
    {
        uint16_t id;
        const char* name;
    };

    static constexpr NamedType knownTypes[] = {
        {types::Common::DeviceType, "DeviceType"},
        {types::Common::ChallengeNonce, "ChallengeNonce"},
        {types::Common::DeviceSerialNumber, "DeviceSerialNumber"},
        {types::Common::DeviceSerialNumberArray, "DeviceSerialNumberArray"},
        {types::Common::FirmwareVersion, "FirmwareVersion"},
        {types::Common::AgentVersion, "AgentVersion"},
        {types::Common::LifecycleState, "LifecycleState"},
        {types::Common::TokenIdentifier, "TokenIdentifier"},
        {types::Common::TokenType, "TokenType"},
        {types::Common::TokenConfig, "TokenConfig"},
        {types::Common::NvidiaSignature, "NvidiaSignature"},
        {types::Common::OemSignature, "OemSignature"},
        {types::Common::InstallationStatus, "InstallationStatus"},
        {types::Common::ProcessingStatus, "ProcessingStatus"},
        {types::Common::SkuInformation, "SkuInformation"},
        {types::Common::NvidiaRatchet, "NvidiaRatchet"},
        {types::Common::OemRatchet, "OemRatchet"},
        {types::Common::ValidityCounter, "ValidityCounter"},
        {types::Common::CertificateChain, "CertificateChain"},
        {types::Common::MeasurementTranscript, "MeasurementTranscript"},
        {types::Common::DeviceId, "DeviceId"},
        {types::Common::TokenTypeSubtypeList, "TokenTypeSubtypeList"},
        {types::Common::Payload, "Payload"},
        {types::Common::LegacyToken, "LegacyToken"},
        {types::GPU::FeatureMask, "GPUFeatureMask"},
        {types::GPU::ChipId, "GPUChipId"},
        {types::NBU::KeypairUUID, "NBUKeypairUUID"},
        {types::NBU::PSID, "NBUPSID"},
        {types::NBU::FileDeviceUnique, "NBUFileDeviceUnique"},
        {types::BMCIRoT::TokenVersion, "BMCIRoTTokenVersion"},
        {types::BMCIRoT::NvidiaSignatureAlgorithm,
         "BMCIRoTNvidiaSignatureAlgorithm"},
    };

    for (const NamedType& entry : knownTypes)
    {
        if (entry.id == type)
        {
            return entry.name;
        }
    }
    return std::format("UnknownType(0x{:04X})", type);
}
/**
 * @brief Construct a structure by decoding an encoded TLV buffer.
 * @param input Complete encoded structure (header followed by items).
 * @throws std::runtime_error propagated from decode() on malformed input.
 */
Structure::Structure(const std::vector<uint8_t>& input)
{
    this->decode(input);
}
/**
 * @brief Parse a complete TLV structure and replace this object's contents.
 *
 * Expected layout:
 * - a StructureHeader whose first 4 identifier bytes equal TLV_IDENTIFIER and
 *   whose little-endian size field gives the payload length;
 * - back-to-back TLV items that exactly fill that payload.
 *
 * The header and the item map are built in locals and committed only after
 * the whole buffer validates. A failed decode therefore leaves the previous
 * state untouched. Calling decode() again does not report items from an
 * earlier call as duplicates.
 *
 * @param input Complete encoded structure (header + items).
 * @throws std::runtime_error if any of the following hold:
 *         - the header is truncated or has the wrong identifier;
 *         - the declared size does not match @p input;
 *         - an item is malformed or its type is duplicated;
 *         - the structure contains no items.
 */
void Structure::decode(const std::vector<uint8_t>& input)
{
    if (input.size() < sizeof(StructureHeader))
    {
        throw std::runtime_error(std::format(
            "TLV structure input data is shorter than the header - size: {}",
            input.size()));
    }

    auto parsedHeader = std::make_shared<StructureHeader>();
    std::memcpy(parsedHeader.get(), input.data(), sizeof(StructureHeader));
    if (std::memcmp(parsedHeader->identifier, TLV_IDENTIFIER, 4) != 0)
    {
        throw std::runtime_error(std::format(
            "Invalid TLV identifier - identifier: {:02X}{:02X}{:02X}{:02X}",
            parsedHeader->identifier[0], parsedHeader->identifier[1],
            parsedHeader->identifier[2], parsedHeader->identifier[3]));
    }

    // Compare the declared payload length against the bytes after the header.
    // The subtraction cannot underflow (checked above). Unlike
    // "size + sizeof(header)" it cannot wrap when size_t is 32 bits.
    const auto payloadSize = static_cast<size_t>(le32toh(parsedHeader->size));
    if (payloadSize != input.size() - sizeof(StructureHeader))
    {
        throw std::runtime_error(std::format(
            "Invalid TLV header data size - expected: {}, actual: {}",
            payloadSize + sizeof(StructureHeader), input.size()));
    }

    decltype(data) items;
    const std::span<const uint8_t> bytes{input};
    size_t offset = sizeof(StructureHeader);
    while (offset < bytes.size())
    {
        try
        {
            // Item's constructor rejects items that overrun the buffer, so
            // offset never advances past bytes.size().
            auto element = Item(bytes.subspan(offset));
            const auto type = element.getType();
            const auto elementSize = element.getTotalSize();
            if (items.contains(type))
            {
                throw std::runtime_error(
                    std::format("Duplicate TLV data type - type: {}", type));
            }
            items.insert({type, std::move(element)});
            offset += elementSize;
        }
        catch (const std::runtime_error& e)
        {
            throw std::runtime_error(
                std::format("Invalid TLV data - {}", e.what()));
        }
    }

    if (items.empty())
    {
        throw std::runtime_error("TLV structure input data is empty");
    }

    // Commit only after everything validated.
    header = std::move(parsedHeader);
    data = std::move(items);
}
/// @brief Return the structure's (major, minor) version from the header,
///        each converted from little-endian to host byte order.
std::pair<uint16_t, uint16_t> Structure::getVersion() const
{
    const uint16_t majorVersion = le16toh(header->versionMajor);
    const uint16_t minorVersion = le16toh(header->versionMinor);
    return std::pair<uint16_t, uint16_t>(majorVersion, minorVersion);
}
/**
 * @brief List the type identifiers of all decoded items.
 * @return One entry per item, in the item container's iteration order.
 */
std::vector<uint16_t> Structure::getTypes() const
{
    std::vector<uint16_t> typeIds;
    typeIds.reserve(data.size());
    // Bind by const auto& so each entry is referenced, not converted. The old
    // `const std::pair<uint16_t, Item>&` parameter did not match the map's
    // pair<const uint16_t, Item> and copied every Item, payload included.
    for (const auto& entry : data)
    {
        typeIds.push_back(entry.first);
    }
    return typeIds;
}
/**
 * @brief Look up a decoded item by its type identifier.
 * @param type TLV type identifier in host byte order.
 * @return Reference to the stored item; valid while this Structure is alive
 *         and not re-decoded.
 * @throws std::runtime_error if no item of that type was decoded.
 */
const Item& Structure::get(uint16_t type) const
{
    // A single find() replaces the contains() + at() double lookup.
    if (const auto it = data.find(type); it != data.end())
    {
        return it->second;
    }
    throw std::runtime_error(
        std::format("TLV item not found - type: {}", type));
}
} // namespace debug_token::tlv_decoder