// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "config.h"

#include <fcntl.h>
#include <libcr51sign/cr51_image_descriptor.h>
#include <libcr51sign/libcr51sign.h>
#include <libcr51sign/libcr51sign_support.h>
#include <sys/types.h>
#include <unistd.h>

#include <flasher/device.hpp>
#include <flasher/file.hpp>
#include <flasher/mod.hpp>
#include <flasher/mutate.hpp>
#include <flashupdate/info.hpp>
#include <flashupdate/logging.hpp>
#include <flashupdate/validator/cr51.hpp>
#include <stdplus/raw.hpp>

#include <format>
#include <memory>
#include <optional>
#include <span>
#include <stdexcept>
#include <string_view>
#include <utility>
#include <vector>

// Pairs a file descriptor wrapper with an optional offset.
// NOTE(review): nothing in this file references FdState; it may be used by
// other translation units or be left over — confirm before removing.
struct FdState
{
    // Descriptor wrapper (stdplus::ManagedFd; presumably closes the fd when
    // destroyed — confirm against stdplus).
    stdplus::ManagedFd fd;
    // Optional offset; presumably where I/O on `fd` starts — TODO confirm.
    std::optional<size_t> offset;
};

// Image source used by the ReadFromFd() libcr51sign callback. Cr51::verify()
// points it at the image being validated; while it is null, ReadFromFd()
// reports LIBCR51SIGN_ERROR_RUNTIME_FAILURE.
flasher::Reader* cr51Reader = nullptr;
// Backing file for the stored MAUV record, used by ReadImageMauv() and
// WriteImageMauv(). Cr51::verify() sets it only while validating; while it is
// null, both callbacks report LIBCR51SIGN_NO_STORED_MAUV_FOUND.
flasher::File* mauvManager = nullptr;

int ReadFromFd(const void*, uint32_t offset, uint32_t count,
               uint8_t* buf) noexcept
{
    if (cr51Reader == nullptr)
    {
        return LIBCR51SIGN_ERROR_RUNTIME_FAILURE;
    }

    try
    {
        cr51Reader->readAtExact(
            std::span<std::byte>{reinterpret_cast<std::byte*>(buf), count},
            offset);
        return LIBCR51SIGN_SUCCESS;
    }
    catch (const std::exception& e)
    {
        LOG(flashupdate::LogLevel::Error, "Reading {}", e.what());
        return LIBCR51SIGN_ERROR_RUNTIME_FAILURE;
    }
}

int ReadImageMauv(const void*, uint8_t* const mauv, uint32_t* const size,
                  const uint32_t maxSize) noexcept
{
    if (maxSize > IMAGE_MAUV_DATA_MAX_SIZE)
    {
        return LIBCR51SIGN_ERROR_MAX;
    }
    if (mauvManager == nullptr)
    {
        return LIBCR51SIGN_NO_STORED_MAUV_FOUND;
    }

    try
    {
        mauvManager->readAtExact(
            std::span<std::byte>{reinterpret_cast<std::byte*>(mauv),
                                 sizeof(struct image_mauv)},
            CR51_MAUV_OFFSET);
    }
    catch (const std::exception& e)
    {
        LOG(flashupdate::LogLevel::Error, "Reading MAUV {}\n", e.what());
        return LIBCR51SIGN_ERROR_RETRIEVING_STORED_IMAGE_MAUV_DATA;
    }

    uint32_t mauvStructVersion;
    memcpy(&mauvStructVersion, mauv, sizeof(mauvStructVersion));
    if (mauvStructVersion != IMAGE_MAUV_STRUCT_VERSION)
    {
        return LIBCR51SIGN_NO_STORED_MAUV_FOUND;
    }
    struct image_mauv* mauvTmp = reinterpret_cast<struct image_mauv*>(mauv);
    size_t versionDenylistOffset =
        offsetof(struct image_mauv, version_denylist);

    size_t denyListSize =
        std::min(mauvTmp->version_denylist_num_entries * sizeof(uint64_t),
                 IMAGE_MAUV_DATA_MAX_SIZE - sizeof(struct image_mauv));
    try
    {
        mauvManager->readAtExact(
            std::span<std::byte>{
                reinterpret_cast<std::byte*>(mauv + versionDenylistOffset),
                // Deny List size.
                std::min(mauvTmp->version_denylist_num_entries *
                             sizeof(uint64_t),
                         IMAGE_MAUV_DATA_MAX_SIZE - sizeof(struct image_mauv))},
            CR51_MAUV_OFFSET + versionDenylistOffset);
    }
    catch (const std::exception& e)
    {
        LOG(flashupdate::LogLevel::Error, "Reading MAUV {}\n", e.what());
        return LIBCR51SIGN_ERROR_RETRIEVING_STORED_IMAGE_MAUV_DATA;
    }

    *size = std::min(maxSize, static_cast<uint32_t>(
                                  sizeof(struct image_mauv) + denyListSize));
    return LIBCR51SIGN_SUCCESS;
}

/// libcr51sign `store_new_image_mauv_data` callback: writes `size` bytes of
/// MAUV data from `mauv` to mauvManager at CR51_MAUV_OFFSET.
///
/// @return LIBCR51SIGN_SUCCESS after the write;
///         LIBCR51SIGN_NO_STORED_MAUV_FOUND when no MAUV file is configured;
///         LIBCR51SIGN_ERROR_STORING_NEW_IMAGE_MAUV_DATA when staging or
///         writing throws a std::exception.
int WriteImageMauv(const void*, const uint8_t* const mauv,
                   const uint32_t size) noexcept
{
    if (!mauvManager)
    {
        return LIBCR51SIGN_NO_STORED_MAUV_FOUND;
    }

    try
    {
        // Stage the caller's bytes in an owned std::byte buffer for
        // writeAtExact().
        const auto* const first = reinterpret_cast<const std::byte*>(mauv);
        std::vector<std::byte> staged(first, first + size);
        mauvManager->writeAtExact(staged, CR51_MAUV_OFFSET);
    }
    catch (const std::exception& e)
    {
        LOG(flashupdate::LogLevel::Error, "Writing MAUV {}\n", e.what());
        return LIBCR51SIGN_ERROR_STORING_NEW_IMAGE_MAUV_DATA;
    }
    return LIBCR51SIGN_SUCCESS;
}

// TODO: Remove after https://gerrit.openbmc.org/c/47332 is submitted.
namespace google
{
namespace cr51
{

/// Hashes the first ctx->descriptor.descriptor_area_size bytes of
/// `imageDescriptor` using the algorithm named by ctx->descriptor.hash_type.
///
/// @param[in] ctx              libcr51sign context; supplies the hash type and
///                             descriptor area size and is handed to the
///                             hash_init/hash_update/hash_final helpers.
/// @param[in] imageDescriptor  Descriptor area bytes; must contain at least
///                             descriptor_area_size bytes.
///
/// @return A view of the digest held in the `hash` member (it is resized on
///         every call, so the view is only good until the next call), or an
///         empty span if the hash type is unsupported, the buffer is too
///         small, or a hash step reports an error.
std::span<const uint8_t> Cr51SignValidatorIpml::hashDescriptor(
    struct libcr51sign_ctx* ctx, std::span<std::byte> imageDescriptor)
{
    // Create the HASH of the CR51 descriptor on the firmware
    size_t size = 0;
    hash_type hashType = static_cast<hash_type>(ctx->descriptor.hash_type);

    switch (hashType)
    {
        case HASH_SHA2_224:
        case HASH_SHA3_224:
            size = SHA224_DIGEST_LENGTH;
            break;
        case HASH_SHA2_256:
        case HASH_SHA3_256:
            size = SHA256_DIGEST_LENGTH;
            break;
        case HASH_SHA2_384:
        case HASH_SHA3_384:
            size = SHA384_DIGEST_LENGTH;
            break;
        case HASH_SHA2_512:
        case HASH_SHA3_512:
            size = SHA512_DIGEST_LENGTH;
            break;
        case HASH_NONE:
        default:
            stdplus::println(stderr, "CR51 Hash type is not supported: type {}",
                             static_cast<int>(hashType));
            return {};
    }

    // hash_update() below consumes descriptor_area_size bytes from the span's
    // data pointer; refuse rather than read past the end of the buffer.
    const size_t areaSize = ctx->descriptor.descriptor_area_size;
    if (imageDescriptor.size() < areaSize)
    {
        stdplus::println(stderr,
                         "CR51 descriptor buffer too small: {} < {} bytes",
                         imageDescriptor.size(), areaSize);
        return {};
    }

    hash.resize(size);
    int ec = hash_init(ctx, hashType);
    if (ec)
    {
        stdplus::println(stderr, "CR51 Hash init error: ec{}", ec);
        return {};
    }

    ec = hash_update(ctx,
                     reinterpret_cast<const uint8_t*>(imageDescriptor.data()),
                     areaSize);
    if (ec)
    {
        stdplus::println(stderr, "CR51 Hash update error: ec{}", ec);
        return {};
    }

    ec = hash_final(ctx, hash.data());
    if (ec)
    {
        stdplus::println(stderr, "CR51 Hash final error: ec{}", ec);
        return {};
    }

    return hash;
}

/// Runs libcr51sign_validate() on the image described by `ctx`. It wires in
/// the hash/signature helpers from libcr51sign_support, this file's
/// read/MAUV callbacks, and the prodToDev / productionMode policy flags.
///
/// stderr is redirected to /dev/null while libcr51sign runs, to silence its
/// informational messages, and is restored afterwards.
///
/// @param[in,out] ctx  libcr51sign context passed to libcr51sign_validate().
///
/// @return The validated image regions, or std::nullopt if validation fails.
/// @throws std::runtime_error if stderr cannot be duplicated or redirected.
std::optional<struct libcr51sign_validated_regions>
    Cr51SignValidatorIpml::validateDescriptor(struct libcr51sign_ctx* ctx)
{
    // Save the stdio position of stderr. fgetpos() fails on pipes, FIFOs and
    // sockets; in that case `pos` is never used below.
    fpos_t pos;
    const bool havePos = fgetpos(stderr, &pos) == 0;

    // Keep a duplicate of the real stderr descriptor so it can be restored.
    // Without it, stderr would stay pointed at /dev/null for good.
    const int fd = dup(fileno(stderr));
    if (fd < 0)
    {
        throw std::runtime_error("failed to duplicate stderr");
    }

    // Redirect stderr to /dev/null to hide libcr51sign's info messages.
    if (!freopen("/dev/null", "w", stderr))
    {
        // freopen() may already have closed the original descriptor. Put it
        // back for raw writers and release the duplicate before bailing out.
        dup2(fd, STDERR_FILENO);
        close(fd);
        throw std::runtime_error("failed to redirect stderr to /dev/null");
    }

    struct libcr51sign_intf intf = {};
    struct libcr51sign_validated_regions imageRegions = {};

    /* Common functions are from the libcr51sign support header. */
    intf.hash_init = &hash_init;
    intf.hash_update = &hash_update;
    intf.hash_final = &hash_final;
    intf.verify_signature = &verify_signature;
    intf.read = &ReadFromFd;
    intf.retrieve_stored_image_mauv_data = &ReadImageMauv;
    intf.store_new_image_mauv_data = &WriteImageMauv;

    auto returnTrue = []() { return true; };
    auto returnFalse = []() { return false; };
    intf.prod_to_dev_downgrade_allowed = prodToDev ? returnTrue : returnFalse;
    intf.is_production_mode = productionMode ? returnTrue : returnFalse;

    // TODO: Validate the non-static regions after all of the regions are signed
    // This only applies to clean image that is not read from the flash
    // directly.
    const enum libcr51sign_validation_failure_reason ec =
        libcr51sign_validate(ctx, &intf, &imageRegions);

    // Restore the original stderr descriptor and clear any error state left
    // on the stream by the /dev/null redirection.
    fflush(stderr);
    dup2(fd, fileno(stderr));
    close(fd);
    clearerr(stderr);

    // Move stderr back to its saved position, if one was captured.
    if (havePos)
    {
        fsetpos(stderr, &pos);
    }

    if (ec != LIBCR51SIGN_SUCCESS)
    {
        LOG(flashupdate::LogLevel::Notice, "CR51 Validate error: {}",
            libcr51sign_errorcode_to_string(ec));
        return std::nullopt;
    }

    return imageRegions;
}

} // namespace cr51
} // namespace google

namespace flashupdate
{
namespace validator
{
namespace cr51
{

using flasher::ModArgs;
using stdplus::fd::OpenAccess;
using stdplus::fd::OpenFlag;
using stdplus::fd::OpenFlags;

/** @brief Render the version fields of a Cr51 image descriptor as
 *         "<major>.<minor>.<point>.<subpoint>".
 *
 * @param[in] descriptor  Cr51 image descriptor whose image_major,
 *                        image_minor, image_point and image_subpoint fields
 *                        are formatted
 *
 * @return dotted version string of the image
 */
std::string formatImageVersion(const struct image_descriptor& descriptor)
{
    const auto& d = descriptor;
    return std::format("{}.{}.{}.{}", d.image_major, d.image_minor,
                       d.image_point, d.image_subpoint);
}

/// Stores `keys` as the candidate signing keys, then validates the image
/// exposed by `reader` against them via verify().
///
/// @return verify()'s result: true when one of the keys validates the image.
bool Cr51::validateImage(flasher::Reader& reader,
                         const std::vector<std::string>& keys)
{
    Cr51::keys = keys;
    const bool verified = verify(reader);
    return verified;
}

std::string Cr51::imageVersion() const
{
    return version;
}

/// Validates the Cr51 signature of the image behind `reader`, trying each
/// configured key until one succeeds.
///
/// On success it records the image regions, version string and production
/// flag. When hashing succeeds, it also records the descriptor hash, offset
/// and size.
///
/// @param[in] reader  Image to validate. It is exposed to the libcr51sign read
///                    callback through the cr51Reader global.
///
/// @return true if some key validates the image, false otherwise.
/// @throws std::runtime_error if no keys are configured. Exceptions from
///         opening the MAUV file, from validateDescriptor(), or from reading
///         the descriptor area propagate to the caller.
bool Cr51::verify(flasher::Reader& reader)
{
    LOG(LogLevel::Info, "Read CR51 Descriptor with size of {}",
        reader.getSize());
    cr51Reader = &reader;
    valid = false;

    context.start_offset = 0;
    context.end_offset = static_cast<uint32_t>(reader.getSize());
    context.current_image_family = static_cast<image_family>(imageFamily);
    context.current_image_type = IMAGE_PROD;
    context.keyring_len = kSignatureRsa4096Pkcs15KeyLength;
    context.priv = &shaContext;

    if (keys.empty())
    {
        throw std::runtime_error("no valid validation key available");
    }

    // Open the optional MAUV store that ReadImageMauv()/WriteImageMauv()
    // reach through the mauvManager global.
    std::unique_ptr<flasher::File> mauv = nullptr;
    if (!std::string_view(CR51_MAUV_PATH).empty())
    {
        auto mauvMod = ModArgs(CR51_MAUV_PATH);
        mauv = flasher::openFile(mauvMod, OpenFlags(OpenAccess::ReadWrite));
        mauvManager = mauv.get();
    }
    else
    {
        mauvManager = nullptr;
    }

    // Clear mauvManager on every exit path, including an exception thrown by
    // validateDescriptor(). This guard is declared after `mauv`, so it runs
    // before `mauv` is destroyed and the global never points at a freed File.
    struct MauvManagerReset
    {
        ~MauvManagerReset()
        {
            mauvManager = nullptr;
        }
    } mauvManagerReset;

    std::optional<libcr51sign_validated_regions> maybeImageRegions;
    std::string usedKeys;
    for (const auto& key : keys)
    {
        context.keyring = key.data();
        maybeImageRegions = cr51Validator.validateDescriptor(&context);
        if (maybeImageRegions)
        {
            LOG(LogLevel::Notice, "CR51 sign is valid using {}", key);
            break;
        }
        usedKeys += key + ",";
    }
    mauvManager = nullptr;

    if (!maybeImageRegions)
    {
        LOG(LogLevel::Crit, "CR51 sign is invalid for using all keys: {}",
            usedKeys);
        return false;
    }
    const auto& imageRegions = *maybeImageRegions;

    LOG(LogLevel::Notice, "Passed CR51 Sign Validation");
    valid = true;
    regions = std::vector<struct image_region>(
        &imageRegions.image_regions[0],
        &imageRegions.image_regions[0] + imageRegions.region_count);
    version = formatImageVersion(context.descriptor);
    LOG(LogLevel::Notice, "BIOS Version: {}", version);

    std::vector<std::byte> buf(context.descriptor.descriptor_area_size);
    // readAtExact() (as used by ReadFromFd()) fails on a short read instead
    // of leaving the tail of `buf` zero-filled and hashing the wrong bytes.
    reader.readAtExact(std::span<std::byte>(buf),
                       context.descriptor.descriptor_offset);

    // Create the HASH of the CR51 descriptor on the BIOS
    auto hashDescriptor = cr51Validator.hashDescriptor(&context, buf);
    // NOTE(review): if hashing fails, hash/descriptorOffset/descriptorSize
    // keep whatever values an earlier verify() call stored.
    if (!hashDescriptor.empty())
    {
        hash =
            std::vector<uint8_t>(hashDescriptor.begin(), hashDescriptor.end());
        descriptorOffset = context.descriptor.descriptor_offset;
        descriptorSize = context.descriptor.descriptor_area_size;
    }
    // Update the sign type for the BIOS image
    this->prod = context.descriptor.image_type == image_type::IMAGE_PROD;
    return true;
}

/// Collects the regions of the last verified image whose attributes carry
/// IMAGE_REGION_PERSISTENT, IMAGE_REGION_PERSISTENT_RELOCATABLE or
/// IMAGE_REGION_PERSISTENT_EXPANDABLE.
///
/// @return View of the `persistentRegion` member, rebuilt on every call;
///         empty while `valid` is false.
std::span<ImageRegion> Cr51::persistentRegions()
{
    if (!valid)
    {
        return {};
    }

    const auto persistentMask = IMAGE_REGION_PERSISTENT |
                                IMAGE_REGION_PERSISTENT_RELOCATABLE |
                                IMAGE_REGION_PERSISTENT_EXPANDABLE;

    persistentRegion.clear();
    for (const struct image_region& entry : regions)
    {
        const bool persistent =
            (entry.region_attributes & persistentMask) != 0;
        const char* const name =
            reinterpret_cast<const char*>(entry.region_name);
        LOG(LogLevel::Debug,
            "Partition Name: {}, Offset: {}, Size: {}, Persistent?: {}", name,
            entry.region_offset, entry.region_size, persistent);
        if (!persistent)
        {
            continue;
        }
        persistentRegion.emplace_back(
            ImageRegion{name, entry.region_offset, entry.region_size});
    }

    return persistentRegion;
}

/// Hashes the image descriptor area stored at `offset` on the device
/// described by `mod` and compares the digest with `expected`.
///
/// @param[in] mod       Device to read from.
/// @param[in] offset    Byte offset of the descriptor area on the device.
/// @param[in] size      Size of the descriptor area in bytes.
/// @param[in] expected  Digest to compare against.
///
/// @return The version recorded in that descriptor, and whether its digest
///         equals `expected` (false if hashing failed).
/// @throws Via stdplus::raw::copyFrom when `size` is smaller than a struct
///         image_descriptor (presumably std::runtime_error — confirm against
///         stdplus).
std::pair<std::string, bool>
    Cr51::validateHash(flasher::ModArgs mod, uint32_t offset, uint32_t size,
                       std::span<const uint8_t> expected)
{
    auto readDev = flasher::openDevice(mod);
    std::vector<std::byte> fileIn(size);
    readDev->readAt(fileIn, offset);

    // Bounds-checked copy (as imageDescriptor() does) rather than
    // dereferencing a reinterpret_cast of a possibly too-small buffer.
    const auto descriptor =
        stdplus::raw::copyFrom<struct image_descriptor>(fileIn);

    // Get the descriptor hash type from the target partition
    context.descriptor.hash_type = descriptor.hash_type;
    context.descriptor.descriptor_offset = 0;
    context.descriptor.descriptor_area_size = size;
    context.priv = &shaContext;

    auto hashDescriptor = cr51Validator.hashDescriptor(&context, fileIn);
    // Print the full digests. A fixed SHA256_DIGEST_LENGTH slice would read
    // out of bounds when hashing failed (empty span) or a digest is shorter.
    stdplus::println(stderr, "validateHash: current {} vs. expected {}",
                     info::bytesToHex(hashDescriptor),
                     info::bytesToHex(expected));

    std::string version = formatImageVersion(descriptor);
    bool validHash = hashDescriptor.size() == expected.size() &&
                     std::equal(hashDescriptor.begin(), hashDescriptor.end(),
                                expected.begin(), expected.end());

    return {version, validHash};
}

/// Reads the image descriptor at `offset` (within `size` bytes) of the device
/// described by `mod` and returns its dotted version string.
std::string Cr51::fetchVersion(flasher::ModArgs mod, uint32_t offset,
                               uint32_t size)
{
    const struct image_descriptor descriptor =
        imageDescriptor(mod, offset, size);
    return validator::cr51::formatImageVersion(descriptor);
}

/// Reports whether the image descriptor at `offset` of the device described
/// by `mod` is marked IMAGE_PROD.
bool Cr51::isProdImage(flasher::ModArgs mod, uint32_t offset, uint32_t size)
{
    const auto imageType = imageDescriptor(mod, offset, size).image_type;
    return imageType == image_type::IMAGE_PROD;
}

/// Reads `size` bytes at `offset` from the device described by `mod` and
/// decodes the leading struct image_descriptor via stdplus::raw::copyFrom
/// (which rejects buffers shorter than the struct).
struct image_descriptor
    Cr51::imageDescriptor(flasher::ModArgs mod, uint32_t offset, uint32_t size)
{
    const auto device = flasher::openDevice(mod);
    std::vector<std::byte> bytes(size);
    device->readAt(bytes, offset);
    return stdplus::raw::copyFrom<struct image_descriptor>(bytes);
}

} // namespace cr51
} // namespace validator
} // namespace flashupdate
